#!/usr/bin/env python3

################################################################################
##                                                                            ##
##  This file is part of NCrystal (see https://mctools.github.io/ncrystal/)   ##
##                                                                            ##
##  Copyright 2015-2022 NCrystal developers                                   ##
##                                                                            ##
##  Licensed under the Apache License, Version 2.0 (the "License");           ##
##  you may not use this file except in compliance with the License.          ##
##  You may obtain a copy of the License at                                   ##
##                                                                            ##
##      http://www.apache.org/licenses/LICENSE-2.0                            ##
##                                                                            ##
##  Unless required by applicable law or agreed to in writing, software       ##
##  distributed under the License is distributed on an "AS IS" BASIS,         ##
##  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  ##
##  See the License for the specific language governing permissions and       ##
##  limitations under the License.                                            ##
##                                                                            ##
################################################################################

################################################################################################
####### Common code for all NCrystal cmdline scripts needing to import NCrystal modules ########
import sys
#Refuse to run on python versions older than the oldest supported one:
pyversion = tuple(sys.version_info)[:3]
_minpyversion = (3, 6, 0)
if not pyversion >= _minpyversion:
    raise SystemExit('Unsupported python version %i.%i.%i detected (needs %i.%i.%i or later).'
                     % (pyversion + _minpyversion))
import os
import pathlib

def maybeThisIsConda():
    """Heuristically guess whether we are running inside a conda environment.

    Returns the value of $CONDA_PREFIX when it is set and non-empty, and
    otherwise whether a 'conda-meta' entry exists under sys.base_prefix. The
    result is only meant to be used for its truthiness.
    """
    conda_prefix = os.environ.get('CONDA_PREFIX', None)
    if conda_prefix:
        return conda_prefix
    return os.path.exists(os.path.join(sys.base_prefix, 'conda-meta'))

def fixSysPathAndImportNCrystal( *, allowfail = False ):
    """Import and return the NCrystal python module.

    If an 'ncrystal-config' file sits next to this script and one of its first
    30 lines starts with '#CMAKE_RELPATH_TO_PYMOD:<relpath>', and
    <thisdir>/<relpath>/NCrystal/__init__.py exists, that directory is
    prepended to sys.path before the import is attempted.

    If the import fails, None is returned when allowfail is True. Otherwise
    SystemExit is raised, with a hint about a likely cause.
    """
    thisdir = pathlib.Path( __file__ ).parent
    marker = '#CMAKE_RELPATH_TO_PYMOD:'

    def extract_cmake_pymodloc():
        #Returns the module location as a Path (based on thisdir), or None.
        p = thisdir / 'ncrystal-config'
        if not p.is_file():
            return None
        #Explicit encoding so the user's locale can not make a stray non-ASCII
        #byte in the config script abort this tool (the marker is pure ASCII,
        #so undecodable bytes are simply replaced):
        with p.open('rt', encoding='utf-8', errors='replace') as fh:
            for i,l in enumerate(fh):
                if i==30:
                    break
                if l.startswith(marker):
                    l = l[len(marker):].strip()
                    return ( thisdir / l ) if l else None
        return None

    pml = extract_cmake_pymodloc()
    hack_syspath = pml and ( pml / 'NCrystal' / '__init__.py' ).exists()
    if hack_syspath:
        sys.path.insert(0,str(pml.absolute().resolve()))
    try:
        import NCrystal
    except ImportError:
        if allowfail:
            return None
        msg = 'ERROR: Could not import the NCrystal Python module'
        if maybeThisIsConda():
            msg += ' (if using conda it might help to close your terminal and activate your environment again)'
        elif not hack_syspath:
            msg += ' (perhaps your PYTHONPATH is misconfigured)'
        raise SystemExit(msg)
    return NCrystal
################################################################################################
#Import NCrystal at start-up (with allowfail left at False, failure ends the
#script via SystemExit with an explanatory message):
NC = fixSysPathAndImportNCrystal()

#########################################################################
########################### System setup ################################
#########################################################################

import argparse
import math

#Unit-test mode: enabled when the NCRYSTAL_INSPECTFILE_UNITTESTS environment
#variable is set to a non-empty value. In that mode no interactive plot windows
#are shown and fewer points are generated, to reduce CPU usage.
_unittest = os.environ.get('NCRYSTAL_INSPECTFILE_UNITTESTS','') != ''

#Import helper for required (but possibly missing) third-party python modules.
#It turns an ImportError into a readable error message.
def import_optpymod(name):
    """Import and return the python module called *name*.

    If the import fails, SystemExit is raised with a readable error message.
    For numpy and matplotlib in an apparent conda environment, the message
    also contains a "conda install" hint.
    """
    import importlib
    try:
        return importlib.import_module(name)
    except ImportError:
        pass
    msg = 'ERROR: Could not import a required python module: %s'%name
    if maybeThisIsConda() and name in ('matplotlib','numpy'):
        msg += ' (looks like you are using conda so you might solve this by running "conda install %s")'%name
    raise SystemExit(msg)


#########################################################################
#########################################################################
#########################################################################

def _create_argparser(dpi_default):
    """Create the command line parser (option definitions and help texts only)."""
    descr="""

The most common usage of this tool is to load input data (usually .ncmat files)
with NCrystal (v%s) and plot resulting isotropic cross sections for thermal
neutrons. This is done by specifying one or more configurations ("cfg-strings"),
which indicates data names (e.g. file names) and optionally cfg parameters
(e.g. temperatures). Specifying more than one configuration, results in a single
comparison plot of the total scattering cross section based on the different
materials. Specifying just a single file, results in a more detailed cross
section plot as well as a 2D plot of generated scatter angles. Other behaviours
can be obtained by specifying flags as indicated below.

"""%NC.__version__

    descr=descr.strip()

    epilog="""
examples:
  $ %(prog)s Al_sg225.ncmat
  plot aluminium cross sections and scatter-angles versus neutron wavelength.
  $ %(prog)s Al_sg225.ncmat Ge_sg227.ncmat --common temp=200
  cross sections for aluminium and germanium at T=200K
  $ %(prog)s "Al_sg225.ncmat;dcutoff=0.1" "Al_sg225.ncmat;dcutoff=0.4" "Al_sg225.ncmat;dcutoff=0.8"
  effect of d-spacing cut-off on aluminium cross sections
  $ %(prog)s "Al_sg225.ncmat;temp=20" "Al_sg225.ncmat;temp=293.15" "Al_sg225.ncmat;temp=600"
  effect of temperature on aluminium cross sections
  $ %(prog)s "phases<0.65*Al_sg225.ncmat&0.35*MgO_sg225_Periclase.ncmat>;temp=100K"
  investigate multiphase material at 100K"""

    parser = argparse.ArgumentParser(description=descr,
                                     epilog=epilog,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('input_cfgs', metavar='CFGSTR', type=str, nargs='*',
                        help="""Input data (cfg-strings) to investigate. This
                        can just be simple file-names or full-blown cfg-strings
                        in the usual NCrystal syntax (see also examples
                        below).""")
    parser.add_argument('--version', action='version', version=str(NC.__version__))
    parser.add_argument('-d','--dump', action='count',default=0,
                        help='Dump derived information rather than displaying plots. Specify multiple times to increase verbosity.')
    parser.add_argument('--common','-c', metavar='CFG', type=str, default=[],
                        help='Common configuration items that will be applied to all input cfg strings',action='append')
    parser.add_argument('--coh_elas','--bragg', action='store_true',
                        help="""Only generate coherent-elastic (Bragg diffraction) component""")
    parser.add_argument('--incoh_elas', action='store_true',
                        help="""Only generate incoherent-elastic component""")
    parser.add_argument('--sans', action='store_true',
                        help="""Only generate SANS component""")
    parser.add_argument('--elastic', action='store_true',
                        help="""Only generate elastic components (including SANS)""")
    parser.add_argument('--inelastic', action='store_true',
                        help="""Only generate inelastic components""")
    parser.add_argument('-a','--absorption', action='store_true',
                        help="""Include absorption in cross section plots""")
    parser.add_argument('--phases', action='store_true',
                        help="""Show cross section breakdown of a single multiphase material by phase rather than physics process""")
    parser.add_argument('-x','--xrange', type=str,nargs='?',
                        help='Override plot range, e.g. "1e-5:1e2" or "0:10"')
    parser.add_argument('--logy', action='store_true',
                        help='Force y-axis to use logarithmic scale.')
    parser.add_argument('--liny', action='store_true',
                        help='Force y-axis to use linear scale.')
    parser.add_argument('-e','--energy', action='store_true',
                        help="""Show plots versus neutron energy rather than wavelength""")
    parser.add_argument('-p','--pdf', action='store_true',
                        help="""Generate PDF file rather than launching an interactive plot viewer.""")
    parser.add_argument('--test', action='store_true',
                        help="""Perform quick validation of NCrystal installation.""")
    parser.add_argument('--doc', action='count',default=0,
                        help="""Print documentation about the available cfg-str variables. Specify twice for more detailed help.""")
    parser.add_argument('--dpi', default=-1,type=int,
                        help="""Change plot resolution. Set to 0 to leave matplotlib defaults alone.
                        (default value is %i, or whatever the NCRYSTAL_DPI env var is set to)."""%dpi_default)
    parser.add_argument('--cfg',action='store_true',
                        help='Print normalised cfg-string and dump meta-data about loaded physics processes.')
    parser.add_argument('--plugins', action='store_true',
                        help='List currently enabled loaded plugins.')
    parser.add_argument('-b','--browse', action='store_true',
                        help='List data available in standard locations (e.g. the files in the current directory or search path)')
    parser.add_argument('--extract', type=str, default=None, metavar="DATANAME",
                        help='''Extract contents of DATANAME (e.g. a file name) using the same lookup mechanism as used for data
                        specified in NCrystal cfg strings. This can therefore also be used to inspect
                        in-memory (or on-demand created) data.''')
    return parser

def _parse_xrange(parser,xrange_str):
    """Parse an "xmin:xmax" string into a (xmin,xmax) tuple of floats.

    Requires 0<=xmin<xmax with xmax finite; otherwise parser.error is called
    (which exits the script).
    """
    try:
        fields = xrange_str.split(':')
        if len(fields) != 2:
            raise ValueError
        xmin, xmax = float(fields[0]), float(fields[1])
        #(comparisons are False for NaN, so NaN is rejected as well):
        if not ( xmin >= 0.0 and xmax > xmin and math.isfinite(xmax) ):
            raise ValueError
    except ValueError:
        parser.error(f'Invalid --xrange argument: "{xrange_str}"')
    return xmin, xmax

def _resolve_default_dpi(parser,dpi_default):
    """DPI to use when --dpi was not specified.

    Returns the NCRYSTAL_DPI environment variable if it is set and non-empty,
    and otherwise dpi_default. An invalid NCRYSTAL_DPI value is reported via
    parser.error, so the script exits with a non-zero status.
    """
    envval = os.environ.get('NCRYSTAL_DPI',None)
    if not envval:
        return dpi_default
    try:
        dpi = int(envval)
        if dpi < 0:
            raise ValueError
    except ValueError:
        parser.error('NCRYSTAL_DPI environment variable must be set to integer >=0')
    if dpi > 3000:
        parser.error('Too high DPI value requested via NCRYSTAL_DPI environment variable.')
    return dpi

def parse_cmdline():
    """Parse and validate the command line arguments, returning the namespace.

    Invalid arguments or combinations are reported via parser.error, which
    exits the script. The returned namespace is adjusted as follows:
      - args.logy: defaults to args.energy unless --logy/--liny was given.
      - args.xrange: None, or a (xmin,xmax) tuple of floats.
    When --extract/--plugins/--doc/--browse is given, the namespace is
    returned at that point. Otherwise, additionally:
      - args.dpi: resolved (command line, NCRYSTAL_DPI env var or default).
      - args.comp: 'coh_elas', 'incoh_elas', 'sans', 'elastic', 'inelastic'
        or 'all'.
      - args.common: all --common entries joined with ';'.
    """
    dpi_default=200
    parser = _create_argparser(dpi_default)
    args=parser.parse_args()

    if args.logy and args.liny:
        parser.error('Do not specify both --liny and --logy')
    if not args.logy and not args.liny:
        args.logy = args.energy

    if args.xrange:
        args.xrange = _parse_xrange(parser,args.xrange)

    has_single_cfgstr = len(args.input_cfgs)==1
    for flag_active, flag_name in ( (args.cfg,'--cfg'),
                                    (args.phases,'--phases'),
                                    (args.dump,'--dump') ):
        if flag_active and not has_single_cfgstr:
            parser.error(f'Option {flag_name} requires exactly one cfg-string to be specified.')

    if args.extract or args.plugins or args.doc or args.browse:
        #Informational modes, which need no plot related processing:
        return args

    if args.dpi>3000:
        parser.error('Too high DPI value requested.')
    if args.dpi<-1:
        parser.error('Invalid (negative) DPI value requested.')
    if args.dpi==-1:
        #-1 is the "not specified" default:
        args.dpi = _resolve_default_dpi(parser,dpi_default)

    if args.test:
        if any((args.cfg,args.input_cfgs,args.dump,args.coh_elas,args.incoh_elas,args.sans,args.elastic,args.inelastic,args.absorption,args.pdf,args.phases)):
            parser.error('Do not specify other arguments with --test.')

    selected_comps = [ name for name,flag in ( ('coh_elas',args.coh_elas),
                                                ('incoh_elas',args.incoh_elas),
                                                ('sans',args.sans),
                                                ('elastic',args.elastic),
                                                ('inelastic',args.inelastic) ) if flag ]
    if len(selected_comps) > 1:
        parser.error('Do not specify more than one of: --coh_elas/--bragg, --incoh_elas, --sans, --elastic or --inelastic.')
    args.comp = selected_comps[0] if selected_comps else 'all'

    if args.absorption and selected_comps:
        parser.error('Do not specify --absorption with either of: --coh_elas/--bragg, --incoh_elas, --sans, --elastic or --inelastic.')

    if args.dump and (selected_comps or args.absorption):
        parser.error('Do not specify --dump with either of: --coh_elas/--bragg, --incoh_elas, --sans, --elastic, --inelastic or --absorption.')

    args.common=';'.join(args.common)
    return args

def create_ekins(npoints,range_override):
    """Return npoints geometrically spaced neutron energies (eV).

    Without range_override the span is 1e-5..1e2. range_override is an
    (emin,emax) tuple. Geometric spacing can not start at zero, so a
    non-positive emin is replaced by emax*1e-10.
    """
    emin, emax = range_override if range_override else (1e-5, 1e2)
    if emin <= 0.0:
        emin = emax * 1e-10
    return NC._np_geomspace(emin, emax, npoints)

def create_wavelengths(np,cfgs,npoints,range_override):
    """Return npoints linearly spaced neutron wavelengths (Aa).

    A (wlmin,wlmax) range_override is used as-is. Otherwise the range starts
    at 1e-4 and ends at 1.2 times the largest Bragg threshold among cfgs,
    padded and truncated to an integer. A fallback of 10 is used when no
    non-zero, finite threshold exists. The np parameter is unused.
    """
    if range_override:
        wl_lo, wl_hi = range_override
        return NC._np_linspace( wl_lo, wl_hi, npoints )
    fallback = 10.0 #for materials with no bragg threshold
    largest_bt = max( (c.braggthreshold() or 0.0 for c in cfgs), default=0.0 ) or fallback
    if math.isinf(largest_bt):
        bt_padded = fallback
    else:
        bt_padded = float(int(largest_bt*1.01+1.0))
    return NC._np_linspace( 1e-4, 1.2*bt_padded, npoints )

#Requested matplotlib figure DPI, held in a one-element list so that main() can
#set it. A falsy value leaves the matplotlib default untouched (see import_npplt):
_mpldpi=[None]
#Name of the output file when generating PDF output:
_pdffilename='ncrystal.pdf'
#Cached result of import_npplt(): (numpy, matplotlib.pyplot, PdfPages-or-None):
_npplt = None
def import_npplt(pdf=False):
    """Return a cached (numpy, matplotlib.pyplot, pdfpages) tuple.

    The first call does the following:
      - imports numpy and matplotlib, exiting with a readable error if they
        are missing;
      - applies the requested DPI;
      - binds 'q'/'Q' to closing plot windows;
      - selects the non-interactive 'agg' backend in unit-test or PDF mode;
      - when pdf is true, opens a PdfPages object writing to _pdffilename
        (otherwise the third entry is None).
    Later calls return the cached tuple and must pass the same pdf flag.
    """
    global _npplt
    if _npplt:
        #pdf par must be same as last call:
        assert bool(pdf) == ( _npplt[2] is not None )
        return _npplt

    numpy_mod = import_optpymod('numpy')
    mpl = import_optpymod('matplotlib')
    #(Checks for ancient/unsupported matplotlib versions were dropped, since
    # mpl.compare_versions is deprecated and the minimum version is unclear.)

    requested_dpi = _mpldpi[0]
    if requested_dpi:
        mpl.rcParams['figure.dpi'] = requested_dpi

    #Make it possible to close plot windows by pressing Q:
    if 'keymap.quit' in mpl.rcParams:
        quitkeys = list(mpl.rcParams['keymap.quit'])
        if 'q' not in quitkeys:
            mpl.rcParams['keymap.quit'] = tuple(quitkeys + ['q','Q'])

    if pdf or _unittest:
        mpl.use('agg')
    if pdf:
        try:
            from matplotlib.backends.backend_pdf import PdfPages
        except ImportError:
            raise SystemExit("ERROR: Your installation of matplotlib does not have the required support for PDF output.")
    plt = import_optpymod('matplotlib.pyplot')

    pdfpages = PdfPages(_pdffilename) if pdf else None
    _npplt = (numpy_mod, plt, pdfpages)
    return _npplt

#functions for creating labels and title:

def _remove_common_keyvals(dicts):
    """remove any key from the passed dicts which exists with identical value in all
    the dicts. Returns a single dictionary with entries thus removed."""
    sets=[set((k,v) for k,v in list(d.items())) for d in dicts]
    common = dict(set.intersection(*sets)) if sets else {}
    for k in list(common.keys()):
        for d in dicts:
            d.pop(k,None)
    return common

def _serialise_cfg(part):
    #transform to list of tuples [(key,value),..] where entries can be
    #(parname,value) or compname/filename.

    mpstart = 'phases<'
    if part._cfg.cfgstr.startswith(mpstart):
        assert part._cfg.cfgstr.count('>')==1
        _=part._cfg.cfgstr[7:].split('>')
        main,trailing_common_cfg = _ if len(_)>1 else (_[0],'')
    else:
        _=part._cfg.cfgstr.split(';',1)
        main,trailing_common_cfg = _ if len(_)>1 else (_[0],'')
    l = [ ('[FILENAME]', main.strip() ), #using '[]' in special keys avoids clashes
          ('[COMPNAME]', part._compname ) ] #(cfg strs can't contain such chars)
    for e in trailing_common_cfg.split(';'):
        e=e.strip()
        if e == 'ignorefilecfg':
            raise SystemExit('ERROR: The ignorefilecfg keyword is no longer supported')
        elif e:
            k,v=e.split('=')
            l.append( (k.strip(),v.strip()) )
    return l

def _classify_differences(parts):
    """Serialise the cfg of each part and separate shared from unique entries.

    Returns (common, per_part): common holds the entries identical for all
    parts, and per_part is a list (one dict per part) of the remaining
    entries.
    """
    per_part = [ dict(_serialise_cfg(p)) for p in parts ]
    common = _remove_common_keyvals(per_part)
    return common, per_part

def _cfgdict_to_str(cfgdict):
    fn = cfgdict.pop('[FILENAME]','')
    if '*' in fn:
        #multiphase, fix up a bit
        _=''
        for phase in fn.split('&'):
            fraction,phcfg = phase.split('*')
            if _:
                _ += ' + '
            multsymb = '\u00D7'
            _ += '%s%s(%s)'%(fraction,multsymb,phcfg)
        fn = '{%s}'%_
    o = [fn] if fn else []
    cn = cfgdict.pop('[COMPNAME]','')
    if cfgdict:
        o += [', '.join('%s=%s'%(k,v) for k,v in sorted(cfgdict.items()))]
    if cn:
        o += [ { 'coh_elas':'Coherent elastic',
                 'incoh_elas':'Incoherent elastic',
                 'elastic':'Elastic',
                 'inelastic':'Inelastic',
                 'sans':'SANS',
                 'absorption':'Absorption',
                 'all':'Total scattering',
                 'scattering+absorption':'Total scattering+Absorption'}[cn] ]
    return ' '.join(b%a for a,b in zip(o,['%s','[%s]','(%s)']))

def create_title_and_labels(parts):
    """Return (title, labels) for a plot of the given process objects.

    Cfg entries shared by all parts form the title. Each part's label is
    built from its remaining entries, or is 'default' if none remain.
    """
    common, uniques = _classify_differences(parts)
    title = _cfgdict_to_str(common)
    labels = []
    for u in uniques:
        labels.append( _cfgdict_to_str(u) or 'default' )
    return title, labels

def _end_plot(plt,pdf):
    """Finish the current figure.

    - Unit-test mode: render the figure into os.devnull (exercising the
      drawing code without producing output) and close it.
    - With a PdfPages object: append the figure to the PDF and close it.
    - Otherwise: show it interactively.
    """
    if _unittest:
        #Use a context manager so the devnull handle is closed deterministically
        #instead of being left open until garbage collection:
        with open(os.devnull,'wb') as fh:
            plt.savefig(fh,format='raw',dpi=10)
        plt.close()
    elif pdf:
        pdf.savefig()
        plt.close()
    else:
        plt.show()

def comp2cfgpars(comp):
    """Return extra cfg-string parameters selecting only the component comp.

    The parameters switch off all other scattering components (e.g.
    'inelas=0' for comp='elastic'); for comp='all' an empty string is
    returned.
    """
    switched_off = { 'coh_elas'   : ('incoh_elas','inelas','sans'),
                     'incoh_elas' : ('coh_elas','inelas','sans'),
                     'elastic'    : ('inelas',),
                     'inelastic'  : ('elas',),
                     'sans'       : ('coh_elas','incoh_elas','inelas'),
                     'all'        : () }
    assert comp in switched_off
    return ';'.join( '%s=0' % name for name in switched_off[comp] )

def plot_xsect(cfgs,comp,absorption,pdf,versus_energy,xrange,logy,breakdown_by_phases):
    """Plot isotropic cross sections [barn/atom] versus neutron wavelength or energy.

    The plot content depends on the inputs:
      - A single cfg with comp 'all' or 'elastic': the individual non-null
        scattering components are shown (plus absorption, if requested),
        together with their sum.
      - Several cfgs: one curve per cfg.
      - breakdown_by_phases with a single multiphase cfg: one curve per
        phase, each scaled by that phase's atom number fraction. The SANS
        process of the multiphase material is added when the selected comp
        includes SANS. The sum is shown as well.

    cfgs: list of Cfg objects.
    comp: 'coh_elas', 'incoh_elas', 'elastic', 'inelastic', 'sans' or 'all'.
    absorption: include absorption (only allowed with comp=='all').
    pdf: if true, add the plot to the PDF output instead of showing it.
    xrange: None or an (xmin,xmax) override of the plotted range.
    logy: use a logarithmic y-axis.

    Raises SystemExit for oriented processes.
    """
    assert comp in ('coh_elas','incoh_elas','elastic','inelastic','sans','all')
    assert not absorption or comp=='all'
    scalefactors = None

    if breakdown_by_phases:
        assert len(cfgs)==1
        if cfgs[0].nPhases()==0:
            print("WARNING: --phases ignored for a single phase material")
            breakdown_by_phases = False
        else:
            mothercfg = cfgs[0]
            nphases = mothercfg.nPhases()
            #weight each phase by its fraction of the atoms, so the curves are
            #contributions to the per-atom cross section of the whole material:
            scalefactors = [ mothercfg.getChildPhaseNumberFraction(i) for i in range(nphases) ]
            assert abs(sum(scalefactors)-1.0)<1e-10
            cfgs = [ mothercfg.getChildPhaseCfg(i) for i in range(nphases) ]

    np,plt,pdf = import_npplt(pdf)
    if versus_energy:
        plt.xlabel('Neutron energy [eV]')
    else:
        plt.xlabel('Neutron wavelength [angstrom]')
    plt.ylabel('Cross section [barn/atom]')
    if len(cfgs)==1 and comp in ('all','elastic'):
        cfg0 = cfgs[0]
        if comp=='all':
            parts=[cfg0.get_scatter('coh_elas'),cfg0.get_scatter('incoh_elas'),
                   cfg0.get_scatter('inelastic'),cfg0.get_scatter('sans')]
        else:
            parts=[cfg0.get_scatter('coh_elas'),cfg0.get_scatter('incoh_elas'),cfg0.get_scatter('sans')]
        if absorption:
            parts += [cfg0.get_absorption()]
    elif absorption:
        parts=[c.get_totalxsect() for c in cfgs]
    else:
        parts=[c.get_scatter(comp) for c in cfgs]

    if breakdown_by_phases and comp in ('all','elastic','sans'):
        #The (unscaled) SANS process of the multiphase material is added as an
        #extra curve, but only when the selected component includes SANS:
        sc_sans = mothercfg.get_scatter('sans')
        if not sc_sans._nullprocess:
            scalefactors += [1.0]
            parts += [sc_sans]

    if not breakdown_by_phases:
        #trim unused process types (but always show all in case of breakdown_by_phases):
        parts = [p for p in parts if not p._nullprocess]

    title,labels = create_title_and_labels(parts)
    if not breakdown_by_phases and len(set(labels))!=len(labels):
        print("WARNING: Comparing identical setups?")

    npts = 3000
    if versus_energy:
        ekins = create_ekins(npts,xrange)
    else:
        wavelengths = create_wavelengths(np,cfgs,npts,xrange)
        ekins = NC.wl2ekin(wavelengths)
    xvar = ekins if versus_energy else wavelengths
    need_tot = (len(cfgs)==1 and len(parts)>1) or breakdown_by_phases
    xsects_tot = None
    max_len_label = 0

    #colors inspired by http://www.mulinblog.com/a-color-palette-optimized-for-data-visualization/
    col_red = "#F15854"
    partcols = [
        "#5DA5DA", # (blue)
        "#FAA43A", # (orange)
        "#60BD68", # (green)
        #"#B2912F", # (brown)
        "#B276B2", # (purple)
        #"#DECF3F", # (yellow)
        #"#F17CB0", # (pink)
        "#4D4D4D", # (gray)
        ]
    if not need_tot:
        #red is otherwise reserved for the total curve:
        partcols = [col_red]+partcols

    linewidth = 2.0

    ydatarange = {}
    def update_ydatarange(xsects):
        #track overall min, min non-zero and max of all plotted y-values:
        ynz = xsects[np.nonzero(xsects)]
        y0nonzero,y0,y1 = ( ynz.min() if len(ynz) else None), xsects.min(), xsects.max()
        if y0 < ydatarange.get('ymin',float('inf')):
            ydatarange['ymin'] = y0
        if y0nonzero is not None and y0nonzero < ydatarange.get('ymin_nonzero',float('inf')):
            ydatarange['ymin_nonzero'] = y0nonzero
        if y1 > ydatarange.get('ymax',float('-inf')):
            ydatarange['ymax'] = y1

    for ipart,part in enumerate(parts):
        if part.isOriented():
            raise SystemExit("ERROR: This script can not produce quick cross-section plots for oriented processes (but you can still inspect the material with --dump)")
        xsects = part.crossSectionIsotropic(ekins)
        if scalefactors is not None:
            xsects *= scalefactors[ipart]
        if need_tot:
            if xsects_tot is None:
                xsects_tot = np.zeros(len(xsects))
            xsects_tot += xsects
        label=labels[ipart]
        max_len_label = max(max_len_label,len(label))
        #cycle line styles once all colours have been used:
        ls={0:'-',1:'--',2:':'}.get(ipart//len(partcols),'-.')
        plt.plot(xvar,xsects,label=label,color=partcols[ipart%len(partcols)],lw=linewidth,ls=ls)
        update_ydatarange(xsects)
    if need_tot:
        if comp=='all' and len(cfgs)==1 and not breakdown_by_phases:
            #sanity check: the components must add up to the directly computed total
            if absorption:
                tot_direct = cfgs[0].get_totalxsect()
            else:
                tot_direct = cfgs[0].get_scatter('all')
            xsects_discrepancy = tot_direct.crossSectionIsotropic(ekins)-xsects_tot
            discr_lvl = np.max(np.abs(xsects_discrepancy))
            if discr_lvl > 1e-10:
                print("WARNING: Discrepancy in breakdown into components detected (at the"
                      +f" {discr_lvl} level)!! Some plugins might be incorrectly programmed.")
                plt.plot(xvar,xsects_discrepancy,
                         label='WARNING: Discrepancy!',
                         color="cyan",lw=linewidth)

        update_ydatarange(xsects_tot)
        #(with --phases any comp is possible here, so fall back to 'Total')
        total_label = {'all':'Total','elastic':'Total elastic'}.get(comp,'Total')
        plt.plot(xvar,xsects_tot,
                 label=total_label,
                 color=col_red,lw=linewidth)
    leg_fsize = 'large'
    if max_len_label > 40: leg_fsize = 'medium'
    if max_len_label > 60: leg_fsize = 'small'
    if max_len_label > 80: leg_fsize = 'smaller'
    try:
        if len(parts)>1:
            leg=plt.legend(loc='best',fontsize=leg_fsize)
            if hasattr(leg,'set_draggable'):
                leg.set_draggable(True)
            else:
                leg.draggable(True)
    except TypeError:
        plt.legend(loc='best')
    plt.grid()
    single_yval = bool(ydatarange.get('ymin','n/a')==ydatarange.get('ymax','n/a'))
    if single_yval and ydatarange.get('ymin','n/a')==0.0:
        if logy:
            print('WARNING: Could not set log scale since curves are 0.0 everywhere')
        plt.gca().set_ylim(0.0,1.0)
    elif logy:
        if ydatarange.get('ymin',1.0) <= 0.0:
            ylow=ydatarange.get('ymin_nonzero',ydatarange.get('ymax',1.0)*1e-10)
            if ylow:
                plt.gca().set_ylim(ylow,None)
        plt.gca().set_yscale('log')
    else:
        plt.gca().set_ylim(0,None)

    if versus_energy:
        plt.gca().set_xlim(ekins[0],ekins[-1])
        plt.gca().set_xscale('log')
    else:
        if wavelengths[0]*100<wavelengths[-1]:
            plt.gca().set_xlim(0.0,wavelengths[-1])
        else:
            plt.gca().set_xlim(wavelengths[0],wavelengths[-1])
    if title:
        if len(title)>30:
            plt.title(title,fontsize='smaller')
        else:
            plt.title(title)
    _end_plot(plt,pdf)

def plot_2d_scatangle(cfg,comp,pdf,versus_energy,xrange):
    """Scatter plot of sampled scattering angles versus neutron wavelength (or energy).

    At each point of a fine wavelength (or energy) grid, a Poisson-distributed
    number of scatterings is sampled, proportional to the cross section of
    the selected component at that point (about 25000 in total). The
    resulting scattering angles, in degrees, are plotted.

    cfg: a Cfg object.
    comp: 'coh_elas', 'incoh_elas', 'elastic', 'inelastic', 'sans' or 'all'.
    """
    assert comp in ('coh_elas','incoh_elas','elastic','inelastic','sans','all')
    part=cfg.get_scatter(comp)

    np,plt,pdf = import_npplt(pdf)

    #Reproducible plots: draw the per-point counts from a dedicated, fixed-seed
    #numpy generator (seeding python's "random" module does not affect numpy):
    rng = np.random.RandomState(123456)

    #higher granularity wavelengths than for 1D plot to avoid artifacts:
    npts = 100 if _unittest else 30000
    if versus_energy:
        ekins = create_ekins(npts,xrange)
    else:
        wavelengths = create_wavelengths(np,[cfg],npts,xrange)
        ekins = NC.wl2ekin(wavelengths)
    xvar = ekins if versus_energy else wavelengths

    #get title (label should be uninteresting for a single part):
    title,_labels = create_title_and_labels([part])

    #First figure out how many points to put at each wavelength (or energy).
    #The grid length comes from xvar, which exists for both x-axis choices:
    n_at_xvar = np.zeros(len(xvar),dtype=int)
    if not part._nullprocess:
        xsects = part.crossSectionIsotropic(ekins)
        sumxs = np.sum(xsects)
        if sumxs:
            n2d_target = 100 if _unittest else 25000
            n_at_xvar = rng.poisson(xsects*n2d_target/sumxs)

    n2d=int(np.sum(n_at_xvar))#actual total, including random fluctuations
    if n2d>0:
        plot_angles = np.zeros(n2d)
        plot_xvar = np.zeros(n2d)
        j = 0
        for i,n in enumerate(n_at_xvar):
            n = int(n)
            if n == 0:
                continue #nothing to sample at this grid point
            _ekinfinal,mu = part.sampleScatterIsotropic(ekins[i],repeat=n)
            plot_angles[j:j+n] = np.arccos(mu)
            plot_xvar[j:j+n] = xvar[i]
            j+=n
        plot_angles *= (180./np.pi)
    else:
        plot_angles = None
        plot_xvar = None

    if plot_xvar is not None:
        try:
            plt.scatter(plot_xvar,plot_angles,alpha=0.2,marker='.',edgecolor=None,color='black',s=2,zorder=1)
        except ValueError:
            plt.scatter(plot_xvar,plot_angles,alpha=0.2,edgecolor=None,color='black',s=2,zorder=1)

    if versus_energy:
        plt.gca().set_xlim(ekins[0],ekins[-1])
        plt.gca().set_xscale('log')
    else:
        if wavelengths[0]*100<wavelengths[-1]:
            plt.gca().set_xlim(0.0,wavelengths[-1])
        else:
            plt.gca().set_xlim(wavelengths[0],wavelengths[-1])
    plt.gca().set_ylim(0,180)
    if versus_energy:
        plt.xlabel('Neutron energy [eV]')
    else:
        plt.xlabel('Neutron wavelength [angstrom]')
    plt.ylabel('Scattering angle [degrees]')
    if title:
        plt.title(title)
    plt.grid()
    _end_plot(plt,pdf)

class XSSum:
    """Cross-section-only sum of several processes (e.g. scattering+absorption).

    It offers just the small part of the process interface needed for cross
    section plots. There is deliberately no sampleScatterIsotropic method,
    since the summed processes can include absorption.
    """
    def __init__(self,*processes):
        self._p = tuple(processes)

    def crossSectionIsotropic(self,ekin):
        """Sum of crossSectionIsotropic(ekin) over all processes (0 if there are none)."""
        total = 0
        for proc in self._p:
            total = total + proc.crossSectionIsotropic(ekin)
        return total

    def isOriented(self):
        """True if any of the summed processes is oriented."""
        for proc in self._p:
            if proc.isOriented():
                return True
        return False

class Cfg:
    """Lazily created NCrystal objects for one normalised cfg-string.

    Info, scatter components, absorption and the scattering+absorption sum are
    created on first request and cached.
    """
    def __init__(self,cfgstr, commoncfgstr):
        self._cfgstr = NC.normaliseCfg('%s;%s'%(cfgstr,commoncfgstr))
        self._sc = {}          #comp name -> scatter object
        self._abs = None       #absorption object, created on demand
        self._totxs = None     #XSSum of absorption+scattering, created on demand
        self._info = None      #Info object, created on demand
        self._bt = 'not_init'  #cached braggthreshold() result
        self._iphase = None    #phase index, when created via getChildPhaseCfg

    def nPhases(self):
        """Number of child phases (0 for single phase materials)."""
        return len(self.get_info().phases)

    def getChildPhaseCfg(self,iphase):
        """Return a new Cfg selecting just phase number iphase (via 'phasechoice')."""
        assert 0 <= iphase < self.nPhases()
        childcfg = Cfg( self._cfgstr,'phasechoice=%i'%iphase )
        childcfg._iphase = iphase
        return childcfg

    def getChildPhaseNumberFraction(self,iphase):
        """Fraction of atoms in phase number iphase.

        Computed as the volume fraction times the phase number density,
        divided by the total number density.
        """
        info = self.get_info()
        volfrac,iph = info.phases[iphase]
        return volfrac*iph.numberdensity / info.numberdensity

    def braggthreshold(self):
        """in Aa (or None). Largest BT value for any crystalline phase."""
        if self._bt != 'not_init':
            return self._bt
        def largestbtrecursive(info):
            if info.isSinglePhase():
                return info.braggthreshold
            bts = [ largestbtrecursive(ph[1]) for ph in info.phases ]
            bts = [ e for e in bts if e ]
            return max(bts) if bts else None
        self._bt = largestbtrecursive( self.get_info() )
        return self._bt

    def get_scatter(self,comp = 'all', allowfail = False):
        """Scatter object for component comp (see comp2cfgpars), cached.

        With allowfail=True, None is returned (and nothing is cached) if
        creation raises NC.NCException; otherwise the exception propagates.
        """
        if comp not in self._sc:
            extra_cfg = comp2cfgpars(comp)
            cstr = ';'.join([self._cfgstr,extra_cfg])
            try:
                sc = NC.createScatter(cstr)
            except NC.NCException:
                if allowfail:
                    return None
                else:
                    raise
            sc._nullprocess = sc.isNull()
            sc._cfg = self
            sc._compname = comp
            self._sc[comp] = sc
        return self._sc[comp]

    #NB: the caches below are checked with "is None", so that caching never
    #depends on the truthiness of NCrystal objects.

    def get_absorption(self):
        """Absorption object (cached)."""
        if self._abs is None:
            a = NC.createAbsorption(self._cfgstr)
            a._nullprocess = a.isNull()
            a._cfg = self
            a._compname = 'absorption'
            self._abs = a
        return self._abs

    def get_totalxsect(self):
        """XSSum of absorption and total scattering (cached)."""
        if self._totxs is None:
            a,s = self.get_absorption(),self.get_scatter('all')
            t = XSSum(a,s)
            t._nullprocess = a._nullprocess and s._nullprocess
            t._cfg = self
            t._compname = 'scattering+absorption'
            self._totxs = t
        return self._totxs

    def get_info(self):
        """Info object (cached)."""
        if self._info is None:
            self._info = NC.createInfo(self._cfgstr)
        return self._info

    @property
    def cfgstr(self):
        """The normalised cfg-string."""
        return self._cfgstr

def _print_cfg_debug_info(args):
    """Print the cfg-string in several forms, then dump the resulting processes.

    The forms are: as given, with any component selection appended, and
    normalised. The absorption and scattering process objects created from
    the normalised string are then dumped. Requires a single input cfg-string.
    """
    assert len(args.input_cfgs)==1
    cfgstr = args.input_cfgs[0]
    if args.common:
        cfgstr += f';{args.common}'
    print(f'==> Debugging cfg-string: "{cfgstr}"')
    extra = comp2cfgpars(args.comp)
    if extra:
        assert args.comp != 'all'
        print(f'==> Adding due to --{args.comp} flag specified: "{extra}"')
        cfgstr += ';' + extra
    normcfg = NC.normaliseCfg(cfgstr)
    print(f'==> Normalised cfg-string : "{normcfg}"')
    absorption_proc = NC.createAbsorption(normcfg)
    scatter_proc = NC.createScatter(normcfg)
    print( '==> Absorption process (code level objects):')
    absorption_proc.dump(' '*27)
    print( '==> Scattering process (code level objects):')
    scatter_proc.dump(' '*27)

def _finalise_pdf(args):
    """Fill in the PDF metadata, close the PDF file and report its name."""
    _np,_plt,pdfpages = import_npplt(True)
    import datetime
    try:
        meta = pdfpages.infodict()
    except AttributeError:
        meta = {}
    plural = '' if len(args.input_cfgs)==1 else 's'
    basenames = ','.join(os.path.basename(f) for f in args.input_cfgs)
    meta['Title'] = 'Plots made with NCrystal-inspectfile from file%s %s'%(plural,basenames)
    meta['Author'] = 'NCrystal %s (via inspectfile)'%NC.__version__
    meta['Subject'] = 'NCrystal plots'
    meta['Keywords'] = 'NCrystal'
    meta['CreationDate'] = datetime.datetime.today()
    meta['ModDate'] = datetime.datetime.today()
    pdfpages.close()
    print("created %s"%_pdffilename)

def main():
    """Script entry point: parse the command line and carry out the requested action."""
    args = parse_cmdline()

    #Informational modes, each of which ends the script:
    if args.doc:
        want_full = args.doc >= 2
        print(NC.generateCfgStrDoc( 'txt_full' if want_full else 'txt_short' ),end='')
        if not want_full:
            print("NOTE: Condensed output generated. Specify --doc twice for more details.")
        raise SystemExit

    if args.extract:
        rawdata = NC.createTextData(args.extract).rawData
        if rawdata is None:
            raise SystemExit('Error: unknown file "%s"'%args.extract)
        print(rawdata,end='')
        raise SystemExit

    if args.plugins:
        NC.browsePlugins(dump=True)
        raise SystemExit

    if args.browse:
        NC.browseFiles(dump=True)
        raise SystemExit

    if args.test:
        NC.test()
        raise SystemExit

    #Dump cfg debug info if requested or running with just 1 file:
    if args.cfg or ( len(args.input_cfgs)==1 and not args.dump ):
        _print_cfg_debug_info(args)
        if args.cfg:
            raise SystemExit(0)

    _mpldpi[0] = args.dpi

    cfgs=[Cfg(e,args.common) for e in args.input_cfgs]
    normstrs = [c.cfgstr for c in cfgs]
    for cstr in set(normstrs):
        if normstrs.count(cstr)!=1:
            print("WARNING: Configuration specified more than once: \"%s\""%cstr)

    if args.dump:
        assert len(cfgs)==1
        cfgs[0].get_info().dump(verbose = int(args.dump)-1 )
        return
    if not cfgs:
        raise SystemExit('Error: nothing selected. Please run with --help for usage instructions.')

    plot_xsect( cfgs, comp  = args.comp, absorption = args.absorption, pdf=args.pdf,
                versus_energy=args.energy, xrange = args.xrange, logy = args.logy,
                breakdown_by_phases = args.phases )
    skip_2d = bool(os.environ.get('NCRYSTAL_INSPECTFILE_NO2DSCATTER',0))
    if len(cfgs)==1 and not skip_2d:
        plot_2d_scatangle( cfgs[0], comp = args.comp, pdf=args.pdf, versus_energy=args.energy, xrange = args.xrange )
    if args.pdf:
        _finalise_pdf(args)

#Run the tool only when executed as a script (not when imported):
if __name__=='__main__':
    main()

# TODO:
#   - allow tuning of n2d and alpha pars for 2D plot?
#   - show hkl values in 1d and 2d plot?
#   - option to show mfp rather than xsect?
#   - deltaE/wl_out plots as well. Perhaps a --wl=... option, producing such plots for that wavelength.
