import meep as mp
import cmath
import random
import argparse

def main(args):
    """Simulate an OLED stack and print how emitted power is partitioned.

    The cell is a 3D stack (top to bottom): PML + glass substrate, ITO,
    organic layer, aluminum. A set of point current sources with random
    phases is placed inside the organic layer. DFT flux monitors record:

    * the total power leaving a small box around the sources,
    * the upward power crossing a plane just inside the glass,
    * the lateral power leaving a four-sided box around the ITO/organic
      waveguide.

    For each wavelength the script prints the fraction of total power going
    into the glass, into the waveguide, and the remainder
    (``1 - glass - waveguide``, labelled "aluminum"). Each value is
    resampled onto a uniform wavelength grid with cubic interpolation. One
    ``data:, <lambda>, <glass>, <waveguide>, <aluminum>`` line is printed
    per grid point.

    Args:
        args: parsed command-line namespace with attributes
            ``L`` (lateral length of the non-absorbing region, um),
            ``perp_dipole`` (True -> Ez sources, False -> Ex sources) and
            ``load_structure`` (True -> read the dielectric structure from
            ``oled_epsilon.h5`` instead of building the geometry; when False
            the structure is dumped to that file after the run).
    """
    # Imported up front (rather than after the run) so a missing
    # dependency is reported before the expensive 3D simulation starts.
    import numpy as np
    from scipy import interpolate
    from meep.materials import Al

    resolution = 100            # pixels/um

    lambda_min = 0.4            # minimum source wavelength
    lambda_max = 0.8            # maximum source wavelength
    fmin = 1/lambda_max         # minimum source frequency
    fmax = 1/lambda_min         # maximum source frequency
    fcen = 0.5*(fmin+fmax)      # source frequency center
    df = fmax-fmin              # source frequency width

    tABS = lambda_max           # absorber/PML thickness
    tGLS = 1.0                  # glass thickness
    tITO = 0.1                  # indium tin oxide thickness
    tORG = 0.1                  # organic thickness
    tAl = 0.1                   # aluminum thickness

    # length of computational cell along Z
    sz = tABS+tGLS+tITO+tORG+tAl
    # length of non-absorbing region of computational cell in X and Y
    L = args.L
    sxy = L+2*tABS
    cell_size = mp.Vector3(sxy,sxy,sz)

    # z coordinate of the glass/ITO interface. Layers are stacked downward
    # from the top of the cell; every layer/monitor position below is
    # expressed relative to this plane.
    z_gls_ito = 0.5*sz-tABS-tGLS

    boundary_layers = [mp.Absorber(tABS,direction=mp.X),
                       mp.Absorber(tABS,direction=mp.Y),
                       mp.PML(tABS,direction=mp.Z,side=mp.High)]

    ORG = mp.Medium(index=1.75)
    ITO = mp.Medium(index=1.80)
    GLS = mp.Medium(index=1.45)

    # glass block also fills the top PML region so the PML terminates glass
    geometry = [mp.Block(material=GLS, size=mp.Vector3(mp.inf,mp.inf,tABS+tGLS), center=mp.Vector3(0,0,0.5*sz-0.5*(tABS+tGLS))),
                mp.Block(material=ITO, size=mp.Vector3(mp.inf,mp.inf,tITO), center=mp.Vector3(0,0,z_gls_ito-0.5*tITO)),
                mp.Block(material=ORG, size=mp.Vector3(mp.inf,mp.inf,tORG), center=mp.Vector3(0,0,z_gls_ito-tITO-0.5*tORG)),
                mp.Block(material=Al, size=mp.Vector3(mp.inf,mp.inf,tAl), center=mp.Vector3(0,0,z_gls_ito-tITO-tORG-0.5*tAl))]

    # current-source component and the matching mirror symmetries
    if args.perp_dipole:
        src_cmpt = mp.Ez
        symmetries = [mp.Mirror(mp.X,+1), mp.Mirror(mp.Y,+1)]
    else:
        src_cmpt = mp.Ex
        symmetries = [mp.Mirror(mp.X,-1), mp.Mirror(mp.Y,+1)]

    num_src = 10                 # number of point sources
    # Sources sit on the z axis between 40% and 58% of the organic
    # thickness below the ITO/organic interface. Each gets a uniformly
    # random phase (presumably to approximate incoherent emission).
    sources = [mp.Source(mp.GaussianSource(fcen, fwidth=df), component=src_cmpt,
                         center=mp.Vector3(0,0,z_gls_ito-tITO-0.4*tORG-0.2*tORG*n/num_src),
                         amplitude=cmath.exp(2*cmath.pi*random.random()*1j))
               for n in range(num_src)]

    if args.load_structure:
        epsilon_filename = 'oled_epsilon.h5'
        geometry = []
    else:
        epsilon_filename = ''

    sim = mp.Simulation(resolution=resolution,
                        cell_size=cell_size,
                        boundary_layers=boundary_layers,
                        geometry=geometry,
                        dimensions=3,
                        sources=sources,
                        force_complex_fields=True,
                        load_structure=epsilon_filename,
                        symmetries=symmetries)

    # number of frequency bins for DFT fields
    nfreq = 50

    def add_flux(center, size, direction, weight):
        """Register a flux plane spanning the full source bandwidth."""
        return sim.add_flux(fcen, df, nfreq,
                            mp.FluxRegion(center=center, size=size,
                                          direction=direction, weight=weight))

    # surround source with a six-sided box of flux planes: top at the
    # glass/ITO interface, bottom at 80% depth into the organic layer
    srcbox_width = 0.05
    srcbox_h = tITO+0.8*tORG
    srcbox_zc = z_gls_ito-0.5*srcbox_h
    srcbox = [
        add_flux(mp.Vector3(0,0,z_gls_ito), mp.Vector3(srcbox_width,srcbox_width,0), mp.Z, +1),
        add_flux(mp.Vector3(0,0,z_gls_ito-tITO-0.8*tORG), mp.Vector3(srcbox_width,srcbox_width,0), mp.Z, -1),
        add_flux(mp.Vector3(0.5*srcbox_width,0,srcbox_zc), mp.Vector3(0,srcbox_width,srcbox_h), mp.X, +1),
        add_flux(mp.Vector3(-0.5*srcbox_width,0,srcbox_zc), mp.Vector3(0,srcbox_width,srcbox_h), mp.X, -1),
        add_flux(mp.Vector3(0,0.5*srcbox_width,srcbox_zc), mp.Vector3(srcbox_width,0,srcbox_h), mp.Y, +1),
        add_flux(mp.Vector3(0,-0.5*srcbox_width,srcbox_zc), mp.Vector3(srcbox_width,0,srcbox_h), mp.Y, -1),
    ]

    # padding for flux box to fully capture waveguide mode
    fluxbox_dpad = 0.05

    # upward flux into glass substrate (plane fluxbox_dpad above the ITO)
    glass_flux = add_flux(mp.Vector3(0,0,0.5*sz-tABS-(tGLS-fluxbox_dpad)), mp.Vector3(L,L,0), mp.Z, +1)

    # surround ORG/ITO waveguide with four-sided box of flux planes
    # NOTE: waveguide mode extends partially into Al cathode and glass
    # substrate, hence the fluxbox_dpad padding on both ends in z
    wvgbox_h = fluxbox_dpad+tITO+tORG+fluxbox_dpad
    wvgbox_zc = z_gls_ito-0.5*(tITO+tORG)
    wvgbox = [
        add_flux(mp.Vector3(0.5*L,0,wvgbox_zc), mp.Vector3(0,L,wvgbox_h), mp.X, +1),
        add_flux(mp.Vector3(-0.5*L,0,wvgbox_zc), mp.Vector3(0,L,wvgbox_h), mp.X, -1),
        add_flux(mp.Vector3(0,0.5*L,wvgbox_zc), mp.Vector3(L,0,wvgbox_h), mp.Y, +1),
        add_flux(mp.Vector3(0,-0.5*L,wvgbox_zc), mp.Vector3(L,0,wvgbox_h), mp.Y, -1),
    ]

    sim.run(until_after_sources=mp.stop_when_fields_decayed(50, src_cmpt, mp.Vector3(0,0,z_gls_ito-tITO-0.5*tORG), 1e-8))

    if not args.load_structure:
        sim.dump_structure('oled_epsilon.h5')

    def summed_flux(monitors):
        """Sum the per-frequency fluxes of several monitors."""
        return sum(np.asarray(mp.get_fluxes(m)) for m in monitors)

    flux_total = summed_flux(srcbox)
    flux_waveguide = summed_flux(wvgbox)
    flux_glass = np.asarray(mp.get_fluxes(glass_flux))

    frac_glass = flux_glass/flux_total
    frac_waveguide = flux_waveguide/flux_total
    frac_aluminum = 1-frac_glass-frac_waveguide

    freqs = np.asarray(mp.get_flux_freqs(glass_flux))
    lambdas = 1/freqs
    # Build the uniform grid from the wavelengths actually sampled, not from
    # lambda_min/lambda_max. The monitor frequencies come from fcen +/- df/2,
    # which can differ from 1/lambda_max and 1/lambda_min by rounding.
    # interp1d raises ValueError for any point even one ulp outside its data
    # range, which would abort after the whole simulation has run.
    lambdas_linear = np.linspace(lambdas.min(), lambdas.max(), nfreq)

    # interp1d sorts its (descending) wavelength samples internally
    frac_linear = [interpolate.interp1d(lambdas, frac, kind='cubic')(lambdas_linear)
                   for frac in (frac_glass, frac_waveguide, frac_aluminum)]

    for row in zip(lambdas_linear, *frac_linear):
        print("data:, {:.4f}, {:.6f}, {:.6f}, {:.6f}".format(*row))

if __name__ == '__main__':
    # Command-line entry point: parse options and hand them to main().
    cli = argparse.ArgumentParser()
    cli.add_argument('-L', type=float, default=4.0,
                     help='length of OLED (default: 4.0 um)')
    # boolean switches, registered in the same order as before
    for flag, description in (
            ('--perp_dipole', 'perpendicular dipole (default: False)'),
            ('--load_structure', 'load structure from HDF5 file? (default: False)'),
    ):
        cli.add_argument(flag, action='store_true', help=description)
    main(cli.parse_args())
