# Compute the reflectance and transmittance of all diffracted orders
# for a planewave incident on a slanted grating, with the design based on
# T. Levola and P. Laakkonen, Optics Express, Vol. 15, pp. 2067-74, 2007

import math
import cmath
import numpy as np
import argparse
import sys
import meep as mp

def _diffraction_angle(kdom, k_mag, angle_tol):
  """Return the signed angle (radians) between a diffracted order and the X axis.

  kdom is the dominant wavevector reported for the mode and k_mag is the
  wavevector magnitude in the medium where the mode propagates. The sign of
  the angle follows the sign of kdom.y.

  When kdom.x equals k_mag to within angle_tol the order is treated as exactly
  normal and 0 is returned; this also keeps acos() away from arguments that
  round-off has pushed marginally above 1.
  """
  # k_mag must be the complete magnitude (e.g. the product ng*fcen as a single
  # value): writing `x % ng*fcen` inline evaluates as `(x % ng)*fcen` because
  # `%` and `*` share the same precedence and associate left to right.
  if kdom.x % k_mag > angle_tol:
    return np.sign(kdom.y)*math.acos(kdom.x/k_mag)
  return 0


def _mode_powers(sim, flux, num_orders, parity, k_mag, alpha_dir,
                 incident_flux, label, kdom_tol, angle_tol):
  """Print and return the normalized power of each propagating diffracted order.

  Requests eigenmode coefficients for bands 1..num_orders on the given flux
  monitor with the given parity. For every mode whose dominant wavevector has
  an X component larger than kdom_tol, prints
  "<label>, <mode index>, <angle in degrees>, <power>" and collects the power,
  which is |alpha|^2 divided by incident_flux.

  alpha_dir selects the last index of the coefficient array; per the Meep
  documentation 0 is the forward (+) and 1 the backward (-) direction.

  Returns the list of powers in mode order.
  """
  res = sim.get_eigenmode_coefficients(flux, range(1,num_orders+1), eig_parity=parity)
  coeffs = res.alpha
  kdoms = res.kdom
  powers = []
  for nm in range(num_orders):
    kdom = kdoms[nm]
    if kdom.x > kdom_tol:
      angle = _diffraction_angle(kdom, k_mag, angle_tol)
      power = abs(coeffs[nm,0,alpha_dir])**2/incident_flux
      print("{}, {}, {:.2f}, {:.8f}".format(label,nm,math.degrees(angle),power))
      powers.append(power)
  return powers


def main(args):
  """Simulate a slanted grating and print the power in each diffracted order.

  Reads from args: res (pixels/um), dd (grating period), hh (grating height),
  gg (grating gap), theta_1 and theta_2 (sidewall angles in degrees),
  src_pol (1: Ez, 2: Hz) and src_angle (degrees).

  Two simulations are run: an empty cell to record the incident flux, then the
  same cell with the substrate and grating. The reflected and transmitted
  fields are decomposed into eigenmodes and one line is printed per
  propagating order, followed by a "total:" line with the summed reflectance,
  transmittance and their sum.

  Exits through sys.exit() if args.src_pol is neither 1 nor 2.
  """
  resolution = args.res  # pixels/um

  dpml = 1.0             # absorbing-layer thickness
  dair = 4.0             # padding length between absorber and grating
  dsub = 3.0             # substrate thickness
  d = args.dd            # grating period
  h = args.hh            # grating height
  g = args.gg            # grating gap
  theta_1 = math.radians(args.theta_1)  # grating sidewall angle #1
  theta_2 = math.radians(args.theta_2)  # grating sidewall angle #2

  sx = dpml+dair+h+dsub+dpml
  sy = d

  cell_size = mp.Vector3(sx,sy,0)
  pml_layers = [mp.Absorber(thickness=dpml,direction=mp.X)]

  wvl = 0.5              # center wavelength
  fcen = 1/wvl           # center frequency
  df = 0.05*fcen         # frequency width

  ng = 1.716             # episulfide refractive index @ 0.532 um
  glass = mp.Medium(index=ng)

  if args.src_pol == 1:
    src_cmpt = mp.Ez
    eig_parity = mp.ODD_Z
  elif args.src_pol == 2:
    src_cmpt = mp.Hz
    eig_parity = mp.EVEN_Z
  else:
    sys.exit("error: src_pol={} is invalid".format(args.src_pol))

  # rotation angle of incident planewave source; CCW about Z axis, 0 degrees along +X axis
  theta_src = math.radians(args.src_angle)

  # k (in source medium) with correct length (plane of incidence: XY)
  k = mp.Vector3(math.cos(theta_src),math.sin(theta_src),0).scale(fcen)
  if theta_src == 0:
    k = mp.Vector3(0,0,0)

  def pw_amp(k,x0):
    # amplitude profile exp(2*pi*i*k.(x+x0)) applied along the line source
    def _pw_amp(x):
      return cmath.exp(1j*2*math.pi*k.dot(x+x0))
    return _pw_amp

  src_pt = mp.Vector3(-0.5*sx+dpml+0.2*dair,0,0)
  sources = [mp.Source(mp.GaussianSource(fcen,fwidth=df),
                       component=src_cmpt,
                       center=src_pt,
                       size=mp.Vector3(0,sy,0),
                       amp_func=pw_amp(k,src_pt))]

  # normalization run: no geometry, only the incident wave crosses the monitor
  sim = mp.Simulation(resolution=resolution,
                      cell_size=cell_size,
                      boundary_layers=pml_layers,
                      k_point=k,
                      sources=sources)

  refl_pt = mp.Vector3(-0.5*sx+dpml+0.7*dair,0,0)
  refl_flux = sim.add_flux(fcen, 0, 1, mp.FluxRegion(center=refl_pt, size=mp.Vector3(0,sy,0)))

  sim.run(until_after_sources=100)

  input_flux = mp.get_fluxes(refl_flux)
  input_flux_data = sim.get_flux_data(refl_flux)

  sim.reset_meep()

  # substrate block (extending into the absorber) plus one slanted ridge
  geometry = [mp.Block(material=glass, size=mp.Vector3(dpml+dsub,mp.inf,mp.inf), center=mp.Vector3(0.5*sx-0.5*(dpml+dsub),0,0)),
              mp.Prism(material=glass,
                       height=mp.inf,
                       vertices=[mp.Vector3(0.5*sx-dpml-dsub,0.5*sy,0),
                                 mp.Vector3(0.5*sx-dpml-dsub-h,0.5*sy-h*math.tan(theta_2),0),
                                 mp.Vector3(0.5*sx-dpml-dsub-h,-0.5*sy+g-h*math.tan(theta_1),0),
                                 mp.Vector3(0.5*sx-dpml-dsub,-0.5*sy+g,0)])]

  sim = mp.Simulation(resolution=resolution,
                      cell_size=cell_size,
                      boundary_layers=pml_layers,
                      k_point=k,
                      sources=sources,
                      geometry=geometry)

  refl_flux = sim.add_flux(fcen, 0, 1, mp.FluxRegion(center=refl_pt, size=mp.Vector3(0,sy,0)))
  # load the negated incident-field data recorded in the normalization run
  sim.load_minus_flux_data(refl_flux, input_flux_data)

  tran_pt = mp.Vector3(0.5*sx-dpml-0.5*dsub,0,0)
  tran_flux = sim.add_flux(fcen, 0, 1, mp.FluxRegion(center=tran_pt, size=mp.Vector3(0,sy,0)))

  sim.run(until_after_sources=500)

  kdom_tol = 1e-2
  angle_tol = 1e-6

  if theta_src == 0:
    # at normal incidence the modes are requested separately for each Y
    # parity, and the number of bands requested per parity is halved
    parity_passes = [(" (even_y)", eig_parity+mp.EVEN_Y),
                     (" (odd_y)", eig_parity+mp.ODD_Y)]
    order_scale = 0.5
  else:
    parity_passes = [("", eig_parity)]
    order_scale = 1

  def order_powers(tag, flux, k_mag, alpha_dir):
    # NOTE(review): the closed interval [ceil, floor] below holds
    # floor-ceil+1 integers; confirm that omitting the +1 is intended.
    num_orders = int(order_scale*(np.floor((k_mag-k.y)*d)-np.ceil((-k_mag-k.y)*d)))
    powers = []
    for suffix, parity in parity_passes:
      powers.extend(_mode_powers(sim, flux, num_orders, parity, k_mag, alpha_dir,
                                 input_flux[0], "{}:{}".format(tag,suffix),
                                 kdom_tol, angle_tol))
    return powers

  # reflected orders travel backward (-X) in air, transmitted orders forward (+X) in glass
  Rsum = sum(order_powers("refl", refl_flux, fcen, 1))
  Tsum = sum(order_powers("tran", tran_flux, ng*fcen, 0))

  print("total:, {:.6f}, {:.6f}, {:.6f}".format(Rsum,Tsum,Rsum+Tsum))
    
if __name__ == '__main__':
  # Command-line options as (flag, type, default, help text), registered in this order.
  cli_options = (
    ('-res', int, 50, 'resolution (default: 50 pixels/um)'),
    ('-src_pol', int, 1, 'source polarization (1: Ez, 2: Hz, default: Ez)'),
    ('-src_angle', float, 0, 'source angle (default: 0 degrees)'),
    ('-dd', float, 0.6, 'grating periodicity (default: 0.6 um)'),
    ('-gg', float, 0.1, 'grating gap (default: 0.1 um)'),
    ('-hh', float, 0.4, 'grating height (default: 0.4 um)'),
    ('-theta_1', float, 12.8, 'grating sidewall angle #1 (default: 12.8 degrees)'),
    ('-theta_2', float, 27.4, 'grating sidewall angle #2 (default: 27.4 degrees)'),
  )

  cli = argparse.ArgumentParser()
  for flag, value_type, default_value, help_text in cli_options:
    cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

  # Parse sys.argv and run the simulation with the resulting namespace.
  main(cli.parse_args())
