import clr, os, winreg
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np

# This boilerplate requires the 'pythonnet' module.
# The following instructions are for installing the 'pythonnet' module via pip:
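#    1. Ensure your Python version is supported by your installed pythonnet release. (When this boilerplate
#       was written, pythonnet did not yet support Python 3.8+, so Python 3.4-3.7 was required; check the
#       pythonnet release notes for current support.)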
#    2. Install 'pythonnet' from pip via a command prompt
#    (type 'cmd' from the start menu or press Windows + R and type 'cmd' then enter)
#
#        python -m pip install pythonnet

class PythonUserExtension(object):
    """Boilerplate connection for an OpticStudio ZOS-API User Extension.

    On construction it locates the ZOS-API assemblies via the Windows registry,
    connects to the OpticStudio application, verifies it is running in Plugin
    mode with a valid API license, and exposes ``ZOSAPI``, ``TheConnection``,
    ``TheApplication`` and ``TheSystem`` as attributes.
    """

    class LicenseException(Exception):
        pass

    class ConnectionException(Exception):
        pass

    class InitializationException(Exception):
        pass

    class SystemNotPresentException(Exception):
        pass

    def __init__(self, path=None):
        """Connect to OpticStudio in Plugin mode.

        Parameters
        ----------
        path : str, optional
            Custom OpticStudio directory passed to
            ``ZOSAPI_Initializer.Initialize``; when ``None`` the initializer
            locates the installed OpticStudio on its own.

        Raises
        ------
        InitializationException
            OpticStudio could not be located, the application could not be
            acquired, or it was not started in Plugin mode.
        ConnectionException
            The .NET connection object could not be created.
        LicenseException
            The license does not permit ZOS-API use.
        SystemNotPresentException
            No primary system is available.
        """
        # Determine the location of ZOSAPI_NetHelper.dll from the registry.
        # Both registry handles are released even if the lookup raises.
        with winreg.ConnectRegistry(None, winreg.HKEY_CURRENT_USER) as registry, \
                winreg.OpenKey(registry, r"Software\Zemax", 0, winreg.KEY_READ) as a_key:
            zemax_data = winreg.QueryValueEx(a_key, 'ZemaxRoot')
        net_helper = os.path.join(os.sep, zemax_data[0], r'ZOS-API\Libraries\ZOSAPI_NetHelper.dll')
        clr.AddReference(net_helper)
        import ZOSAPI_NetHelper  # only importable after AddReference above

        # Find the installed version of OpticStudio, or use the caller-supplied path.
        if path is None:
            is_initialized = ZOSAPI_NetHelper.ZOSAPI_Initializer.Initialize()
        else:
            is_initialized = ZOSAPI_NetHelper.ZOSAPI_Initializer.Initialize(path)

        if not is_initialized:
            raise PythonUserExtension.InitializationException(
                "Unable to locate Zemax OpticStudio.  Try using a hard-coded path.")

        # Determine the ZOS root directory and add the ZOS-API references.
        zemax_dir = ZOSAPI_NetHelper.ZOSAPI_Initializer.GetZemaxDirectory()
        clr.AddReference(os.path.join(os.sep, zemax_dir, "ZOSAPI.dll"))
        clr.AddReference(os.path.join(os.sep, zemax_dir, "ZOSAPI_Interfaces.dll"))
        import ZOSAPI

        # Keep a reference to the API namespace for callers.
        self.ZOSAPI = ZOSAPI

        # Create the initial connection class.
        self.TheConnection = ZOSAPI.ZOSAPI_Connection()
        if self.TheConnection is None:
            raise PythonUserExtension.ConnectionException("Unable to initialize .NET connection to ZOSAPI")

        self.TheApplication = self.TheConnection.ConnectToApplication()
        if self.TheApplication is None:
            raise PythonUserExtension.InitializationException("Unable to acquire ZOSAPI application")

        if self.TheApplication.Mode != ZOSAPI.ZOSAPI_Mode.Plugin:
            mode_name = ZOSAPI.ZOSAPI_Mode.GetName(ZOSAPI.ZOSAPI_Mode, self.TheApplication.Mode)
            # Format the mode into the message (previously it was passed as a
            # second exception argument, producing a tuple instead of a message).
            raise PythonUserExtension.InitializationException(
                "User plugin was started in the wrong mode: expected Plugin, found {}".format(mode_name))

        if not self.TheApplication.IsValidLicenseForAPI:
            raise PythonUserExtension.LicenseException("License is not valid for ZOSAPI use")

        self.TheSystem = self.TheApplication.PrimarySystem
        if self.TheSystem is None:
            raise PythonUserExtension.SystemNotPresentException("Unable to acquire Primary system")

        self.TheApplication.ProgressPercent = 0
        self.TheApplication.ProgressMessage = 'Running Extension...'

    def __del__(self):
        """Report completion (message 'Complete', 100 %) to OpticStudio."""
        # getattr: __init__ may have raised before TheApplication was assigned,
        # in which case there is nothing to report and no AttributeError should leak.
        application = getattr(self, 'TheApplication', None)
        if application is not None:
            application.ProgressMessage = 'Complete'
            application.ProgressPercent = 100

    def OpenFile(self, filepath, save_if_needed):
        """Load ``filepath`` into the primary system via ``LoadFile``."""
        if self.TheSystem is None:
            raise PythonUserExtension.SystemNotPresentException("Unable to acquire Primary system")
        self.TheSystem.LoadFile(filepath, save_if_needed)

    def CloseFile(self, save):
        """Close the primary system, passing ``save`` through to ``Close``."""
        if self.TheSystem is None:
            raise PythonUserExtension.SystemNotPresentException("Unable to acquire Primary system")
        self.TheSystem.Close(save)

    def SamplesDir(self):
        """Return the application's ``SamplesDir``."""
        if self.TheApplication is None:
            raise PythonUserExtension.InitializationException("Unable to acquire ZOSAPI application")

        return self.TheApplication.SamplesDir

    def ExampleConstants(self):
        """Return the license edition name.

        Returns one of "Premium", "Professional", "Standard" or "Invalid",
        based on ``TheApplication.LicenseStatus``.
        """
        license_type = self.ZOSAPI.LicenseStatusType
        status = self.TheApplication.LicenseStatus
        if status == license_type.PremiumEdition:
            return "Premium"
        if status == license_type.ProfessionalEdition:
            return "Professional"
        if status == license_type.StandardEdition:
            return "Standard"
        return "Invalid"

    def reshape(self, data, x, y, transpose=False):
        """Converts a System.Double[,] to a 2D list for plotting or post processing

        Parameters
        ----------
        data      : System.Double[,] data directly from ZOS-API (any iterable works)
        x         : x width of new 2D list [use var.GetLength(0) for dimension]
        y         : y width of new 2D list [use var.GetLength(1) for dimension]
        transpose : transposes data; needed for some multi-dimensional line series data

        Returns
        -------
        res       : 2D list of ``x`` rows, each holding the next ``y`` values
                    (rows are shorter if ``data`` runs out); can be used
                    directly with Matplotlib or converted with numpy.asarray(res)
        """
        if not isinstance(data, list):
            data = list(data)
        it = iter(data)
        res = [list(islice(it, y)) for _ in range(x)]
        if transpose:
            return self.transpose(res)
        return res

    def transpose(self, data):
        """Transposes a 2D list (Python 3.x or greater).

        Useful for converting multi-dimensional line series (i.e. FFT PSF)

        Parameters
        ----------
        data      : Python native list (if using System.Double[,] object reshape first)

        Returns
        -------
        res       : transposed 2D list; ragged rows are truncated to the shortest
                    row, as with ``zip``
        """
        if not isinstance(data, list):
            data = list(data)
        return [list(column) for column in zip(*data)]

if __name__ == '__main__':
    import time

    zos = PythonUserExtension()

    # use http://matplotlib.org/ to plot 2D graph
    # need to install this package before running this code

    # load local variables
    ZOSAPI = zos.ZOSAPI
    TheApplication = zos.TheApplication
    TheSystem = zos.TheSystem

    # ! [e02s02_py]
    # Create ray trace
    NSCRayTrace = TheSystem.Tools.OpenNSCRayTrace()
    NSCRayTrace.SplitNSCRays = True
    NSCRayTrace.ScatterNSCRays = False
    NSCRayTrace.UsePolarization = True
    NSCRayTrace.IgnoreErrors = True
    NSCRayTrace.SaveRays = False
    NSCRayTrace.Run()
    # ! [e02s02_py]

    # Poll the running trace and print each even progress percentage once.
    lastValue = 0
    print('Beginning ray trace:')
    while NSCRayTrace.IsRunning:
        currentValue = NSCRayTrace.Progress
        if currentValue % 2 == 0 and currentValue != lastValue:
            lastValue = currentValue
            print(currentValue)
        time.sleep(0.05)  # avoid pegging a CPU core while polling
    NSCRayTrace.WaitForCompletion()
    NSCRayTrace.Close()

    # Non-sequential component editor
    TheNCE = TheSystem.NCE

    DetObj = 4
    obj = TheNCE.GetObjectAt(DetObj)
    numXPixels = obj.ObjectData.NumberXPixels
    numYPixels = obj.ObjectData.NumberYPixels

    #! [e02s03_py]
    # Get detector data as detectorData[row][col] (numYPixels rows of
    # numXPixels values), so it can go straight into imshow.  The pixel index
    # advances fastest along X.  Pixel numbers start at 1: per the NSDD /
    # GetDetectorData convention, pixel 0 returns the detector total rather
    # than a real pixel (verify against your OpticStudio version's help).
    # Pixels whose read fails are marked -1.
    detectorData = [[0] * numXPixels for _ in range(numYPixels)]
    pix = 0
    for row in range(numYPixels):
        for col in range(numXPixels):
            pix += 1
            ret, pixel_val = TheNCE.GetDetectorData(DetObj, pix, 1, 0)
            detectorData[row][col] = pixel_val if ret == 1 else -1
    #! [e02s03_py]

    # end of default code
    # everything below here is based on numpy/matplotlib and is not supported by ZOSAPI or Zemax
    # https://docs.scipy.org/doc/numpy/index.html
    # http://matplotlib.org/

    # Release the extension object; its __del__ reports 'Complete' / 100 % to
    # OpticStudio.  Do this before plt.show(), which blocks until the plot
    # window is closed.
    del zos

    # imshow draws row 0 at the top by default, so the image may appear
    # vertically flipped relative to OpticStudio's own detector viewer.
    plt.imshow(detectorData)
    plt.show()