/* $Id: flat_combine.tcc,v 1.11 2013-08-13 13:16:27 cgarcia Exp $
 *
 * This file is part of the MOSCA library
 * Copyright (C) 2013 European Southern Observatory
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

/*
 * $Author: cgarcia $
 * $Date: 2013-08-13 13:16:27 $
 * $Revision: 1.11 $
 * $Name: not supported by cvs2svn $
 */


#ifndef FLAT_COMBINE_CPP
#define FLAT_COMBINE_CPP

#include <iostream>
#include <vector>
#include <algorithm>
#include <numeric>
#include <stdexcept>
#include "flat_combine.h"
#include "mosca_image.h"
#include "reduce_method.h"
#include "image_utils.h"
#include "vector_utils.h"


/**
 * @brief
 *   Get a master spectroscopy flat 
 *
 * @param image_start An iterator with the first flat image
 * @param image_end   An iterator with the end of the list of flat images.
 *
 * @tparam T    The type of data of the images: float, double, etc...
 * @tparam Iter The type of iterator. If dereferenced it should return a 
 *              mosca::image object
 * TODO: Better interpolation of each value in the SED image
 */
template<typename T, typename Iter, typename CombineMethod>
std::auto_ptr<mosca::image> 
mosca::flat_combine_it(Iter flat_start, Iter flat_end, 
                       std::vector<mosca::detected_slit>& slits,
                       mosca::wavelength_calibration& wave_cal,
                       size_t smooth_size, CombineMethod comb_method)
{

    size_t n_pix_sed = flat_start->size_dispersion();
    /* TODO: Shouldn't be hard-coded DOUBLE */
    cpl_image * master_flat_im = cpl_image_new(flat_start->size_dispersion(), flat_start->size_spatial(), CPL_TYPE_DOUBLE); 
    cpl_image * master_flat_err = cpl_image_new(flat_start->size_dispersion(), flat_start->size_spatial(), CPL_TYPE_DOUBLE); 
    
    /* We work on a slit per slit basis */
    for(std::vector<mosca::detected_slit>::iterator slit_it = slits.begin();
        slit_it != slits.end() ; slit_it++)
    {

        int disp_bottom, spa_bottom, disp_top, spa_top;
        slit_it->get_extent_pix(disp_bottom, spa_bottom,
                                disp_top,    spa_top);

        double min_wave, max_wave;
        wave_cal.min_max_wave(min_wave, max_wave, flat_start->size_dispersion(),
                slit_it->get_position_spatial_corrected(), 
                slit_it->get_position_spatial_corrected() + slit_it->get_length_spatial_corrected());
        double mean_dispersion = wave_cal.mean_dispersion(flat_start->size_dispersion(),
                slit_it->get_position_spatial_corrected(), 
                slit_it->get_position_spatial_corrected() + slit_it->get_length_spatial_corrected());

        std::vector<mosca::image> slit_flats;
        
        for(Iter flat_it = flat_start; flat_it != flat_end ; flat_it++)
        {
            /* Trim the image to the slit area only */
            mosca::image slit_flat = flat_it->trim(disp_bottom, spa_bottom,
                                                   disp_top, spa_top);
            slit_flats.push_back(slit_flat);
        }
        
        /* Get the average flat. We don't care so much if areas 
         * which don't belong to this slit are present */
        mosca::image average_slit_flat = 
                mosca::imagelist_reduce(slit_flats.begin(), 
                                        slit_flats.end(), 
                                        mosca::reduce_mean());
        
        /* Divide each individual flat by the average flat. 
         * This removes the pixel to pixel variation, therefore only differences
         * in large scale variations in wavelength are retained */
        std::vector<mosca::image> pix2pix_slit_flats = 
                slit_flats / average_slit_flat;

        /* All the seds of all the flats */
        std::vector<std::vector<double> > seds;
        
        /* Get the spectral energy distribution for each individual flat */ 
        for(std::vector<mosca::image>::iterator flat_it =  pix2pix_slit_flats.begin(); 
            flat_it != pix2pix_slit_flats.end(); flat_it++)
        {
            mosca::image& slit_flat = *flat_it;
            double * flat_data = slit_flat.get_data<double>();
            double * flat_err = slit_flat.get_data_err<double>();

            /* Compute the SED for this flat */
            std::vector<double> sed(n_pix_sed, 0.);
            std::vector<double> sed_err(n_pix_sed, 0.);
            std::vector<int> nsum(n_pix_sed, 0);
            for(cpl_size j = 0; j< flat_it->size_dispersion(); j++)
            {
                /* When a pixel in column j is corrected from slit distortion, 
                 * it can happen that it corresponds to a column below or beyond
                 * the spectrum limits in the image. This margin_distor will
                 * ensure that +- that margin will be taken into account.
                 */
                for(cpl_size i = 0; i< flat_it->size_spatial(); i++)
                {
                    if(slit_it->within_trace((double)(j + disp_bottom),
                                             (double)(i + spa_bottom)))
                    {

                        double spatial_corrected = slit_it->spatial_correct
                                ((double)(j + disp_bottom),
                                 (double)(i + spa_bottom));
                        double wavelength = wave_cal.get_wave
                                (spatial_corrected, (double)(j + disp_bottom));
                        if(wavelength > min_wave && wavelength < max_wave)
                        {
                            /* TODO: j, i depends on the spectral axis */
                            size_t idx_sed = (wavelength - min_wave) / mean_dispersion - 0.5;
                            sed[idx_sed] += flat_data[j + slit_flat.size_dispersion() * i];
                            sed_err[idx_sed] += flat_err[j + slit_flat.size_dispersion() * i];
                            nsum[idx_sed] += 1;
                        }
                    }
                }
            }
            mosca::vector_divide(sed, sed_err, nsum);

            /* Smooth the SED */
            if(smooth_size > 1)
                mosca::vector_smooth(sed, sed_err, smooth_size);
            
            /* Add this SED to the list of SEDs */
            seds.push_back(sed);
        }

        /* Divide the original flats by its SED */
        std::vector<mosca::image>::iterator flat_it;
        std::vector<std::vector<double> >::iterator sed_it;
        for(sed_it = seds.begin(), flat_it =  slit_flats.begin(); 
                flat_it != slit_flats.end(); flat_it++, sed_it++)
        {
            mosca::image& slit_flat = *flat_it;
            double * flat_data = slit_flat.get_data<double>();
            double * flat_err = slit_flat.get_data_err<double>();
            for(cpl_size j = 0; j< flat_it->size_dispersion(); j++)
            {
                for(cpl_size i = 0; i< flat_it->size_spatial(); i++)
                {
                    if(slit_it->within_trace((double)(j + disp_bottom),
                            (double)(i + spa_bottom)))
                    {
                        double spatial_corrected = slit_it->spatial_correct
                                ((double)(j + disp_bottom),
                                        (double)(i + spa_bottom));
                        double wavelength = wave_cal.get_wave
                                (spatial_corrected, (double)(j + disp_bottom));
                        if(wavelength > min_wave && wavelength < max_wave)
                        {
                            size_t idx_sed = (wavelength - min_wave) / mean_dispersion - 0.5;
                            flat_data[j + slit_flat.size_dispersion() * i] /= (*sed_it)[idx_sed];
                            /* TODO: Use the error in sed_err */
                            flat_err[j + slit_flat.size_dispersion() * i] /= (*sed_it)[idx_sed];
                        }
                    }
                }
            }
        }

        /* Now we can stack the flats in the "usual" way. */
        mosca::image stacked_slit_flat_no_sed = 
                mosca::imagelist_reduce(slit_flats.begin(), 
                                        slit_flats.end(), comb_method);

        /* We compute the average SED */
        std::vector<double> avg_sed(n_pix_sed, 0.);
        for(size_t ipix = 0; ipix < n_pix_sed; ++ipix)
        {
            double sum = 0;
            for(size_t ised = 0; ised < seds.size(); ++ised)
                sum += seds[ised][ipix];
            avg_sed[ipix] = sum / seds.size();
        }
        
        /* Now we multiply the master slit flat by the average SED */
        double * average_slit_flat_no_sed_im = stacked_slit_flat_no_sed.get_data<double>();
        double * average_slit_flat_no_sed_err = stacked_slit_flat_no_sed.get_data_err<double>();
        for(cpl_size j = 0; j< stacked_slit_flat_no_sed.size_dispersion(); j++)
        {
            for(cpl_size i = 0; i< stacked_slit_flat_no_sed.size_spatial(); i++)
            {
                if(slit_it->within_trace((double)(j + disp_bottom),
                                         (double)(i + spa_bottom)))
                {
                    double spatial_corrected = slit_it->spatial_correct
                            ((double)(j + disp_bottom),
                             (double)(i + spa_bottom));
                    double wavelength = wave_cal.get_wave
                            (spatial_corrected, (double)(j + disp_bottom));
                    if(wavelength > min_wave && wavelength < max_wave)
                    {
                         size_t idx_sed = (wavelength - min_wave) / mean_dispersion - 0.5;
                         average_slit_flat_no_sed_im[j + stacked_slit_flat_no_sed.size_dispersion() * i] *= avg_sed[idx_sed];
                     /* TODO: Use the error in sed_err */
                         average_slit_flat_no_sed_err[j + stacked_slit_flat_no_sed.size_dispersion() * i] *= avg_sed[idx_sed];
                    }
                }
            }
        }
        
        
        /* The master slit flat is placed in the master flat */
        cpl_image_copy(master_flat_im, stacked_slit_flat_no_sed.get_cpl_image(),
                       1, spa_bottom);
        cpl_image_copy(master_flat_err, stacked_slit_flat_no_sed.get_cpl_image_err(),
                       1, spa_bottom);
    }
    
    std::auto_ptr<mosca::image> 
        master_flat(new mosca::image(master_flat_im, master_flat_err));
    
    return master_flat;
}

/**
 * @brief
 *   Get a master spectroscopy flat from a vector of flats.
 *
 * Convenience wrapper that forwards the whole vector to
 * mosca::flat_combine_it().
 *
 * @param image_list  A vector with all the flats (must not be empty).
 * @param slits       The detected slits to process.
 * @param wave_cal    The wavelength calibration used to build the SEDs.
 * @param smooth_size Smoothing size applied to each SED (no smoothing if <= 1).
 * @param comb_method Reduction method used to stack the SED-corrected flats.
 *
 * @return The master flat, as returned by mosca::flat_combine_it().
 *
 * @tparam T             The type of data of the images: float, double, etc...
 * @tparam CombineMethod Reduction functor accepted by
 *                       mosca::imagelist_reduce().
 */
template<typename T, typename CombineMethod>
std::auto_ptr<mosca::image> mosca::flat_combine
(std::vector<mosca::image>& image_list, 
 std::vector<mosca::detected_slit>& slits,
 mosca::wavelength_calibration& wave_cal,
 size_t smooth_size, CombineMethod comb_method)
{
    typedef std::vector<mosca::image>::iterator iter_type;
    return mosca::flat_combine_it<T, iter_type, CombineMethod >
        (image_list.begin(), image_list.end(), 
                slits, wave_cal, smooth_size, comb_method);
}


#endif
