# Source code for temgymbasic.run

from functools import partial
from temgymbasic.functions import get_image_from_rays

import PyQt5
from PyQt5.QtWidgets import QMainWindow
from PyQt5.QtWidgets import QVBoxLayout
from PyQt5.QtWidgets import QWidget
from PyQt5.QtWidgets import QScrollArea

import pyqtgraph.opengl as gl
import pyqtgraph as pg
from pyqtgraph.Qt import QtCore
from pyqtgraph.dockarea import Dock, DockArea

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Global matplotlib defaults, applied once at import time. These change
# rcParams for every figure created in this process, not only those made here.
mpl.rcParams.update({
    'font.family': 'Helvetica',
    'axes.titlesize': 32,
    'axes.labelsize': 28,
})

# Package metadata
__version__ = "0.5"
__author__ = "David Landers"

#Create the UI class
class LinearTEMUi(QMainWindow):
    """LinearTEM's viewer (GUI).

    The main window is a DockArea with three docks:
    - a 3D view of the column, its components, the rays and the detector outline;
    - a 2D detector image;
    - a scrollable panel with the GUI controls of the model and of each component.

    Parameters
    ----------
    QMainWindow : class
        PyQt5's main window class (base class).
    """

    def __init__(self, model):
        """Initialise camera presets, docks and the three displays.

        Parameters
        ----------
        model : class
            Microscope model. Must provide ``detector_size``, ``components``
            (each with ``gl_points``, ``label`` and ``create_gui``/``gui``),
            ``create_gui`` and ``gui``.
        """
        super().__init__()

        # Camera presets, selected by LinearTEMCtrl.set_camera_params.
        # NOTE(review): the tiny fov combined with a huge distance presumably
        # approximates an orthographic side view along x / y — confirm.
        self.initial_camera_params = {
            'center': PyQt5.QtGui.QVector3D(-0.5, -0.5, 0.0),
            'fov': 25,
            'azimuth': 45.0,
            'distance': 10,
            'elevation': 25.0,
        }
        self.x_camera_params = {
            'center': PyQt5.QtGui.QVector3D(0.0, 0.0, 0.5),
            'fov': 7e-07,
            'azimuth': 90.0,
            'distance': 143358760,
            'elevation': 0.0,
        }
        self.y_camera_params = {
            'center': PyQt5.QtGui.QVector3D(0.0, 0.0, 0.5),
            'fov': 7e-07,
            'azimuth': 0,
            'distance': 143358760,
            'elevation': 0.0,
        }

        self.model = model

        # Main window properties
        self.setWindowTitle("LinearTEM")
        self.resize(1920, 1080)

        # NOTE: this attribute shadows QMainWindow.centralWidget(); the name is
        # kept so existing code reading ``view.centralWidget`` keeps working.
        self.centralWidget = DockArea()
        self.setCentralWidget(self.centralWidget)

        # Docks: 3D view on the left, detector below it, GUI on the right
        self.tem_dock = Dock("3D View")
        self.detector_dock = Dock("Detector", size=(5, 5))
        self.gui_dock = Dock("GUI", size=(10, 3))
        self.centralWidget.addDock(self.tem_dock, "left")
        self.centralWidget.addDock(self.detector_dock, "bottom", self.tem_dock)
        self.centralWidget.addDock(self.gui_dock, "right")

        # Closed square outline of the detector in the z = 0 plane,
        # with side length ``detector_size``
        scale = self.model.detector_size / 2
        vertices = np.array([[1, 1, 0],
                             [-1, 1, 0],
                             [-1, -1, 0],
                             [1, -1, 0],
                             [1, 1, 0]]) * scale
        self.detector_outline = gl.GLLinePlotItem(pos=vertices, color="w", mode='line_strip')

        # Create the displays and the control panel
        self.create3DDisplay()
        self.createDetectorDisplay()
        self.createGUI()

    def create3DDisplay(self):
        """Create the 3D display containing the axis, detector, components and rays."""
        self.tem_window = gl.GLViewWidget()

        axis = gl.GLAxisItem()
        self.tem_window.addItem(axis)
        self.tem_window.setBackgroundColor((150, 150, 150, 255))
        self.tem_window.setCameraPosition(distance=5)

        # Empty placeholder for the ray lines; LinearTEMCtrl.update() fills it via setData
        self.ray_geometry = gl.GLLinePlotItem(mode='lines', width=2)
        self.tem_window.addItem(self.detector_outline)
        self.tem_window.setCameraParams(**self.initial_camera_params)

        # Add every component's geometry and its label to the 3D window
        for component in self.model.components:
            for geometry in component.gl_points:
                self.tem_window.addItem(geometry)
            self.tem_window.addItem(component.label)

        # Add the ray geometry exactly once, after the components. It used to be
        # added twice (GLViewWidget keeps duplicates in its item list), so the
        # translucent rays were painted twice.
        self.tem_window.addItem(self.ray_geometry)

        self.tem_dock.addWidget(self.tem_window)

    def createDetectorDisplay(self):
        """Create the 2D detector display showing where the rays land."""
        self.detector_window = pg.GraphicsLayoutWidget()
        self.detector_window.setAspectLocked(1.0)
        self.spot_img = pg.ImageItem(border="b")
        view_box = self.detector_window.addViewBox()
        view_box.setAspectLocked()
        # Invert the y axis so the spot moves up on screen when it should
        view_box.invertY()
        view_box.addItem(self.spot_img)
        self.detector_dock.addWidget(self.detector_window)

    def createGUI(self):
        """Create the scrollable panel holding the model and component controls."""
        scroll = QScrollArea()
        scroll.setWidgetResizable(True)
        content = QWidget()
        scroll.setWidget(content)
        # NOTE: this attribute shadows QWidget.layout(); the name is kept for
        # backwards compatibility with code that reads ``view.layout``.
        self.layout = QVBoxLayout(content)
        self.gui_dock.addWidget(scroll, 1, 0)

        # QBoxLayout.addWidget's second positional argument is the *stretch
        # factor*, not a position. Widgets are appended in call order, so no
        # index is passed; this keeps later panels from getting larger stretch.
        self.model.create_gui()
        self.layout.addWidget(self.model.gui.box)

        # One GUI box per component, in component order
        for component in self.model.components:
            component.create_gui()
            self.layout.addWidget(component.gui.box)
# Create a Controller class to connect the GUI and the model
class LinearTEMCtrl:
    """Control code which links the microscope model and the 3D viewer.

    Every GUI control change calls :meth:`update`, which re-traces the rays
    and redraws the viewer. A shared QTimer also calls :meth:`update`
    periodically while any wobble button is checked.
    """

    def __init__(self, model, view):
        """
        Parameters
        ----------
        model : class
            Microscope model
        view : class
            UI viewer (LinearTEMUi)
        """
        self.model = model
        self.view = view

        # Timer that repeatedly calls update() (every 10 ms) while started
        self.timer = QtCore.QTimer()
        self.timer.timeout.connect(self.update)
        self.timer.setInterval(10)

        # Connect signals and slots, then draw the initial state
        self.connectSignals()
        self.update()

    def _wobble_buttons(self):
        """Yield every wobble toggle button that connectSignals wires to timerstart."""
        for component in self.model.components:
            if component.type == 'Lens':
                yield component.gui.fwobble
            elif component.type == 'Double Deflector':
                yield component.gui.xbuttonwobble
                yield component.gui.ybuttonwobble

    def timerstart(self, btn, component):
        """Start or stop the shared timer in response to a wobble button toggle.

        Parameters
        ----------
        btn : PyQt5 Button
            The wobble button whose state changed.
        component : class
            The component whose GUI owns ``btn``; tells us which kind of wobble
            (lens focus or double-deflector x/y) was toggled.
        """
        if btn.isChecked():
            if component.type == 'Double Deflector':
                # X and Y wobble are mutually exclusive for a double deflector
                if btn is component.gui.xbuttonwobble:
                    other = component.gui.ybuttonwobble
                elif btn is component.gui.ybuttonwobble:
                    other = component.gui.xbuttonwobble
                else:
                    return
                # Uncheck the other button *before* starting the timer: its
                # toggled(False) signal re-enters this method, and must not be
                # able to stop the timer we are about to start.
                other.setChecked(False)
                self.timer.start()
            elif component.type == 'Lens':
                self.timer.start()
        elif not any(button.isChecked() for button in self._wobble_buttons()):
            # The timer is shared by all components: only stop it once no
            # wobble button anywhere is still checked.
            self.timer.stop()

    def connectSignals(self):
        """Connect every model and component GUI control to the controller."""
        gui = self.model.gui
        gui.rayslider.valueChanged.connect(self.update)
        gui.checkBoxParalell.stateChanged.connect(self.update)
        gui.checkBoxPoint.stateChanged.connect(self.update)
        gui.checkBoxAxial.stateChanged.connect(self.update)
        gui.beamangleslider.valueChanged.connect(self.update)
        gui.beamwidthslider.valueChanged.connect(self.update)
        gui.init_button.clicked.connect(partial(self.set_camera_params, gui.init_button))
        gui.x_button.clicked.connect(partial(self.set_camera_params, gui.x_button))
        gui.y_button.clicked.connect(partial(self.set_camera_params, gui.y_button))
        gui.xangleslider.valueChanged.connect(self.update)
        gui.yangleslider.valueChanged.connect(self.update)

        for component in self.model.components:
            if component.type == 'Lens':
                component.gui.fslider.valueChanged.connect(self.update)
                component.gui.fwobble.toggled.connect(
                    partial(self.timerstart, component.gui.fwobble, component))
            elif component.type == 'Deflector':
                component.gui.defxslider.valueChanged.connect(self.update)
                component.gui.defyslider.valueChanged.connect(self.update)
            elif component.type == 'Double Deflector':
                component.gui.updefxslider.valueChanged.connect(self.update)
                component.gui.updefyslider.valueChanged.connect(self.update)
                component.gui.lowdefxslider.valueChanged.connect(self.update)
                component.gui.lowdefyslider.valueChanged.connect(self.update)
                component.gui.defratioxslider.valueChanged.connect(self.update)
                component.gui.defratioyslider.valueChanged.connect(self.update)
                component.gui.xbuttonwobble.toggled.connect(
                    partial(self.timerstart, component.gui.xbuttonwobble, component))
                component.gui.ybuttonwobble.toggled.connect(
                    partial(self.timerstart, component.gui.ybuttonwobble, component))
            elif component.type == 'Biprism':
                component.gui.defslider.valueChanged.connect(self.update)
                component.gui.rotslider.valueChanged.connect(self.update)
            elif component.type == 'Aperture':
                component.gui.radiusslider.valueChanged.connect(self.update)
                component.gui.xslider.valueChanged.connect(self.update)
                component.gui.yslider.valueChanged.connect(self.update)
            elif component.type in ('Astigmatic Lens', 'Quadrupole'):
                component.gui.fxslider.valueChanged.connect(self.update)
                component.gui.fyslider.valueChanged.connect(self.update)
            elif component.type == 'Sample':
                component.gui.xslider.valueChanged.connect(self.update)
                component.gui.yslider.valueChanged.connect(self.update)

    def set_camera_params(self, btn):
        """Move the 3D camera to the preset associated with ``btn``.

        Parameters
        ----------
        btn : PyQt5 Button
            One of the model GUI's ``x_button``, ``y_button`` or
            ``init_button``; any other button is ignored.
        """
        gui = self.model.gui
        if btn is gui.x_button:
            params = self.view.x_camera_params
        elif btn is gui.y_button:
            params = self.view.y_camera_params
        elif btn is gui.init_button:
            params = self.view.initial_camera_params
        else:
            return
        self.view.tem_window.setCameraParams(**params)

    def update(self):
        """Sync the model with its GUI, re-trace the rays and redraw the viewer.

        Reads the GUI values into the model and each component, runs
        ``model.step()``, then updates the 3D ray lines and the detector image.
        """
        self.model.update_gui()
        for component in self.model.components:
            component.update_gui()
        self.model.step()

        num_rays = self.model.num_rays

        # Pair each ray's (r[:, 0], r[:, 2]) coordinates with the z position of every step
        ray_z = np.tile(self.model.z_positions, [num_rays, 1, 1]).T
        ray_xyz = np.hstack((self.model.r[:, [0, 2], :], ray_z))

        # Duplicate interior vertices so consecutive pairs form line segments.
        # Shape: [Num Steps*2 - 2, 3, Num Rays]
        lines_repeated = np.repeat(ray_xyz, repeats=2, axis=0)[1:-1]

        # Start with every ray allowed; remove the ones any component blocks
        allowed_rays = range(num_rays)
        for component in self.model.components:
            blocked = component.blocked_ray_idcs
            if len(blocked) == 0:
                continue
            allowed_rays = list(set(allowed_rays).difference(set(blocked)))
            idx = component.index * 2 + 2
            # Coordinates where the blocked rays hit this component
            pts_blocked = lines_repeated[idx, :, blocked]
            # Copy those points onto every later vertex of the blocked rays.
            # Their remaining segments collapse to zero length and are not visible.
            lines_aperture = np.broadcast_to(
                pts_blocked[..., None],
                pts_blocked.shape + (lines_repeated.shape[0] - idx,),
            ).transpose(2, 1, 0)
            lines_repeated[idx:, :, blocked] = lines_aperture

        # Restack [Num Steps*2 - 2, 3, Num Rays] into one long vertex list
        # [(Num Steps*2 - 2) * Num Rays, 3]. See
        # https://stackoverflow.com/questions/38509196/efficiently-re-stacking-a-numpy-ndarray
        lines_paired = lines_repeated.transpose(2, 0, 1).reshape(
            lines_repeated.shape[0] * num_rays, 3)

        # Detector image from the final-plane coordinates of the unblocked rays
        detector_image, _ = get_image_from_rays(
            self.model.r[-1, 2, allowed_rays], self.model.r[-1, 0, allowed_rays],
            self.model.detector_size, self.model.detector_pixels)

        self.view.spot_img.setImage(detector_image.T)
        self.view.ray_geometry.setData(pos=lines_paired, color=(0, 0.8, 0, 0.05))
def run_pyqt(model):
    """Build the PyQt viewer for a microscope model and wire it to a controller.

    Parameters
    ----------
    model : class
        Microscope Model

    Returns
    -------
    LinearTEMUi
        The main window. This function neither shows it nor starts the Qt
        event loop; the caller must do both.
    """
    # Generate the GUI viewer
    viewer = LinearTEMUi(model)
    # Connect the model with the viewer, and keep an explicit reference to the
    # controller. PyQt5 does not keep the receiver of a bound-method slot
    # (e.g. timer.timeout -> update) alive. Without this, the controller and
    # its QTimer survived only incidentally, through the functools.partial
    # slots that connectSignals creates.
    viewer.controller = LinearTEMCtrl(model, viewer)
    return viewer
#Example code to make a matplotlib plot
def show_matplotlib(model, name='model.svg', component_lw=4, edge_lw=1, label_fontsize=20):
    """Draw a 2D ray diagram of the microscope model with matplotlib.

    The rays are plotted using the ``rays[:, 0, :]`` coordinate against the
    model's z positions. Each component is drawn at its z position. Rays
    blocked by a component are not drawn below it. Drawing stops after the
    first component that leaves no rays; later components are not drawn.

    Parameters
    ----------
    model : class
        Microscope Model. Uses ``step()``, ``z_positions``, ``num_rays``,
        ``detector_size``, ``beam_z``, ``beam_type`` and ``components``.
    name : str, optional
        Intended file name, by default 'model.svg'. NOTE: currently unused;
        the figure is not saved. Call ``fig.savefig(name)`` on the result.
    component_lw : int, optional
        Linewidth of component outline, by default 4
    edge_lw : int, optional
        Linewidth of highlight to edges, by default 1
    label_fontsize : int, optional
        Fontsize of labels, by default 20

    Returns
    -------
    fig : class
        Matplotlib figure object
    ax : class
        Matplotlib axis object of the figure
    """
    # Step the rays through the model to get the ray positions throughout the column
    rays = model.step()
    x, z = rays[:, 0, :], model.z_positions

    # Figure and axis styling
    fig, ax = plt.subplots(figsize=(12, 20))
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.tick_params(axis='both', which='minor', labelsize=12)
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    ax.grid(color='lightgrey', linestyle='--', linewidth=0.5)
    ax.grid(which='minor', color='#EEEEEE', linestyle=':', linewidth=0.5)
    ax.set_yticks([])
    ax.set_yticklabels([])
    ax.get_xaxis().set_ticks(
        [-model.detector_size/2, 0, model.detector_size/2])
    ax.set_xlim([-0.5, 0.5])
    ax.set_ylim([0, model.beam_z])
    ax.set_aspect('equal', adjustable='box')
    ax.text(0, model.beam_z, 'Electron Gun', fontsize=label_fontsize, zorder=1000)

    # Index of the z position at the bottom of the next ray segment to draw
    idx = 1
    # Rays not yet blocked by any component. This is a real list (not a range)
    # so that truthiness tests work and num_rays == 0 draws no rays instead of
    # failing on x[..., [0, -1]].
    allowed_rays = list(range(model.num_rays))

    # Drawing style
    ray_color = 'dimgray'
    fill_color = 'aquamarine'
    fill_color_pair = ['khaki', 'deepskyblue']
    fill_alpha = 1
    ray_alpha = 1
    ray_lw = 0.25
    plot_rays = True
    highlight_edges = True
    fill_between = True
    # Indices of the rays on the edges of each contiguous bundle, in (start, end) pairs
    edge_rays = [0, model.num_rays-1]
    label_x = 0.30

    def _draw_segment(seg_idx, rays_shown, edges, *, pair_zorder, edgecolor, lw_none):
        # Draw the ray bundle between z[seg_idx - 1] and z[seg_idx]: black
        # outlines on the bundle edges, a fill between each pair of edges, and
        # the individual rays. The keyword flags reproduce the slightly
        # different fill styling used at each of the three call sites.
        rows = slice(seg_idx - 1, seg_idx + 1)
        if highlight_edges:
            ax.plot(x[rows, edges], z[rows], color='k', linewidth=edge_lw, alpha=1, zorder=2)
        if fill_between:
            pair_idx = 0
            for first, second in zip(edges[::2], edges[1::2]):
                if len(edges) == 4:
                    # Exactly two bundles: give each its own colour
                    color, zorder = fill_color_pair[pair_idx], pair_zorder
                    pair_idx += 1
                else:
                    color, zorder = fill_color, 0
                fill_kw = {'color': color}
                if edgecolor:
                    fill_kw['edgecolor'] = color
                fill_kw['alpha'] = fill_alpha
                fill_kw['zorder'] = zorder
                if lw_none:
                    fill_kw['lw'] = None
                ax.fill_betweenx(z[rows], x[rows, first], x[rows, second], **fill_kw)
        if plot_rays:
            ax.plot(x[rows, rays_shown], z[rows], color=ray_color,
                    linewidth=ray_lw, alpha=ray_alpha, zorder=1)

    def _draw_deflector_plate(z_plate, r):
        # One deflector plate: two coloured halves on top of a thicker black underlay
        ax.plot([-r, 0], [z_plate, z_plate], color='lightcoral', alpha=1, linewidth=component_lw, zorder=999)
        ax.plot([0, r], [z_plate, z_plate], color='lightblue', alpha=1, linewidth=component_lw, zorder=999)
        ax.plot([-r, r], [z_plate, z_plate], color='k', alpha=0.8, linewidth=component_lw+2, zorder=998)

    # Draw the rays arriving at each component, then the component itself,
    # advancing idx by one for each z plane the component occupies
    for component in model.components:
        if allowed_rays:
            _draw_segment(idx, allowed_rays, edge_rays,
                          pair_zorder=0, edgecolor=True, lw_none=True)

        if component.type == 'Biprism':
            ax.text(label_x, component.z-0.01, component.name, fontsize=label_fontsize, zorder=1000)
            if model.beam_type == 'x_axial' and component.theta == 0:
                ax.plot(component.points[0, :], component.points[2, :], color='dimgrey',
                        alpha=0.8, linewidth=component_lw)
            elif model.beam_type == 'x_axial' and component.theta == np.pi/2:
                ax.add_patch(plt.Circle((0, component.z), component.width,
                                        edgecolor='k', facecolor='w', zorder=1000))
            idx += 1
        elif component.type == 'Quadrupole':
            r = component.radius
            # NOTE(review): a single quadrupole is labelled 'Upper ...' — likely
            # copied from the double-deflector branch; confirm the intended label.
            ax.text(label_x, component.z-0.01, 'Upper ' + component.name, fontsize=label_fontsize, zorder=1000)
            ax.plot([-r, -r/2], [z[idx], z[idx]], color='lightcoral', alpha=1, linewidth=component_lw, zorder=999)
            ax.plot([-r/2, 0], [z[idx], z[idx]], color='lightblue', alpha=1, linewidth=component_lw, zorder=999)
            ax.plot([0, r/2], [z[idx], z[idx]], color='lightcoral', alpha=1, linewidth=component_lw, zorder=999)
            ax.plot([r/2, r], [z[idx], z[idx]], color='lightblue', alpha=1, linewidth=component_lw, zorder=999)
            ax.plot([-r, r], [z[idx], z[idx]], color='k', alpha=0.8, linewidth=component_lw+2, zorder=998)
            idx += 1
        elif component.type == 'Aperture':
            ax.text(label_x, component.z-0.01, component.name, fontsize=label_fontsize, zorder=1000)
            ri = component.aperture_radius_inner
            ro = component.aperture_radius_outer
            ax.plot([-ri, -ro], [z[idx], z[idx]], color='dimgrey', alpha=1, linewidth=component_lw, zorder=999)
            ax.plot([ri, ro], [z[idx], z[idx]], color='dimgrey', alpha=1, linewidth=component_lw, zorder=999)
            ax.plot([-ri, -ro], [z[idx], z[idx]], color='k', alpha=1, linewidth=component_lw+2, zorder=998)
            ax.plot([ri, ro], [z[idx], z[idx]], color='k', alpha=1, linewidth=component_lw+2, zorder=998)
            idx += 1
        elif component.type == 'Double Deflector':
            r = component.radius
            # Upper plate
            ax.text(label_x, component.z_up-0.01, 'Upper ' + component.name, fontsize=label_fontsize, zorder=1000)
            _draw_deflector_plate(z[idx], r)
            idx += 1
            # Rays between the upper and lower plates
            if allowed_rays:
                _draw_segment(idx, allowed_rays, edge_rays,
                              pair_zorder=1, edgecolor=False, lw_none=False)
            # Lower plate
            ax.text(label_x, component.z_low-0.01, 'Lower ' + component.name, fontsize=label_fontsize, zorder=1000)
            _draw_deflector_plate(z[idx], r)
            idx += 1
        elif component.type in ('Lens', 'Astigmatic Lens'):
            ax.text(label_x, component.z-0.01, component.name, fontsize=label_fontsize, zorder=1000)
            # Back half of the lens ellipse behind the rays, front half in front of them
            ax.add_patch(mpl.patches.Arc((0, component.z), component.radius*2, height=0.05, theta1=0,
                                         theta2=180, linewidth=1, fill=False, zorder=-1, edgecolor='k'))
            ax.add_patch(mpl.patches.Arc((0, component.z), component.radius*2, height=0.05, theta1=180,
                                         theta2=0, linewidth=1, fill=False, zorder=999, edgecolor='k'))
            idx += 1
        elif component.type == 'Deflector':
            r = component.radius
            ax.text(label_x, component.z-0.01, component.name, fontsize=label_fontsize, zorder=1000)
            _draw_deflector_plate(z[idx], r)
            idx += 1
        elif component.type == 'Sample':
            ax.text(label_x, component.z-0.01, component.name, fontsize=label_fontsize, zorder=1000)
            w = component.width
            ax.plot([component.x-w/2, component.x+w/2], [z[idx], z[idx]],
                    color='dimgrey', alpha=0.8, linewidth=3)
            idx += 1

        # Remove the rays this component blocks, then recompute the bundle edges:
        # the first/last allowed ray plus both sides of every gap in the indices
        allowed_rays = list(set(allowed_rays).difference(set(component.blocked_ray_idcs)))
        allowed_rays.sort()
        if not allowed_rays:
            break
        edge_rays = [allowed_rays[0], allowed_rays[-1]]
        for gap in np.where(np.diff(allowed_rays) != 1)[0].tolist():
            edge_rays.extend([allowed_rays[gap], allowed_rays[gap + 1]])
        edge_rays.sort()

    # Final segment: from the last component down to the detector
    if allowed_rays:
        _draw_segment(idx, allowed_rays, edge_rays,
                      pair_zorder=1, edgecolor=True, lw_none=False)

    # Detector label and outline at z = 0
    ax.text(label_x, -0.01, 'Detector', fontsize=label_fontsize, zorder=1000)
    ax.plot([-model.detector_size/2, model.detector_size/2], [0, 0],
            color='dimgrey', alpha=1, linewidth=component_lw)

    return fig, ax