import os
import uuid

from flask import render_template, redirect, url_for, flash, request, current_app, session
from flask_login import login_required, current_user
from werkzeug.utils import secure_filename

from app.extensions import db
from app.blueprints.datasets import bp
from app.blueprints.datasets.forms import UploadForm, ColumnMappingForm
from app.blueprints.datasets.parsers import parse_csv_preview, ingest_csv
from app.models.dataset import Dataset, DataPoint


@bp.route('/datasets')
@login_required
def list_datasets():
    """Render the current user's datasets, newest first."""
    user_datasets = (
        Dataset.query
        .filter_by(user_id=current_user.id)
        .order_by(Dataset.created_at.desc())
        .all()
    )
    return render_template('list.html', datasets=user_datasets)


@bp.route('/datasets/upload', methods=['GET', 'POST'])
@login_required
def upload():
    """First step of the upload flow: receive a CSV file.

    On a valid POST, saves the file under UPLOAD_FOLDER with a
    collision-proof name, creates a Dataset row in 'mapping' status,
    stashes the pending dataset id and file path in the session, and
    redirects to the column-mapping step. Otherwise renders the form.
    """
    form = UploadForm()
    if form.validate_on_submit():
        f = form.file.data
        filename = secure_filename(f.filename)
        if not filename:
            # secure_filename() can strip a hostile name (e.g. '../') down
            # to the empty string; saving would then target the directory.
            flash('Invalid file name.', 'error')
            return render_template('upload.html', form=form)

        upload_dir = current_app.config['UPLOAD_FOLDER']
        os.makedirs(upload_dir, exist_ok=True)
        # Prefix with a random token so two uploads of the same filename
        # never overwrite each other on disk. Only the session carries the
        # on-disk path; original_filename keeps the user-visible name.
        stored_name = f'{uuid.uuid4().hex}_{filename}'
        filepath = os.path.join(upload_dir, stored_name)
        f.save(filepath)

        # Create dataset record
        dataset = Dataset(
            user_id=current_user.id,
            name=form.name.data,
            description=form.description.data,
            original_filename=filename,
            upload_status='mapping',
        )
        db.session.add(dataset)
        db.session.commit()

        # Store in session for column mapping step
        session['pending_dataset_id'] = dataset.id
        session['pending_filepath'] = filepath

        return redirect(url_for('datasets.map_columns'))
    return render_template('upload.html', form=form)


@bp.route('/datasets/map-columns', methods=['GET', 'POST'])
@login_required
def map_columns():
    """Second step of the upload flow: map CSV columns, then ingest.

    Reads the pending dataset id and file path from the session (set by
    upload()), shows a preview with auto-detected column choices, and on
    submit parses the CSV, bulk-inserts DataPoints, and fills in the
    dataset's denormalized summary fields.
    """
    dataset_id = session.get('pending_dataset_id')
    filepath = session.get('pending_filepath')
    if not dataset_id or not filepath:
        return redirect(url_for('datasets.upload'))

    # Guard against a stale session entry: the dataset may have been
    # deleted meanwhile, or (defensively) not belong to the current user.
    dataset = Dataset.query.get(dataset_id)
    if dataset is None or dataset.user_id != current_user.id:
        session.pop('pending_dataset_id', None)
        session.pop('pending_filepath', None)
        return redirect(url_for('datasets.upload'))

    preview = parse_csv_preview(filepath)
    cols = preview['columns']
    choices = [(c, c) for c in cols]

    form = ColumnMappingForm()
    form.timestamp_col.choices = choices
    form.lat_col.choices = choices
    form.lon_col.choices = choices
    form.animal_id_col.choices = choices

    # Pre-select likely columns on first render. First match wins so an
    # early good hit (e.g. 'timestamp') is not clobbered by a later,
    # looser one (e.g. 'update_date').
    if request.method == 'GET':
        for col in cols:
            cl = col.lower()
            if not form.timestamp_col.data and ('time' in cl or 'date' in cl):
                form.timestamp_col.data = col
            if not form.lat_col.data and 'lat' in cl:
                form.lat_col.data = col
            if not form.lon_col.data and ('lon' in cl or 'lng' in cl):
                form.lon_col.data = col
            if not form.animal_id_col.data and (
                    'animal' in cl or 'id' in cl
                    or 'individual' in cl or 'tag' in cl):
                form.animal_id_col.data = col

    if form.validate_on_submit():
        column_mapping = {
            'timestamp': form.timestamp_col.data,
            'lat': form.lat_col.data,
            'lon': form.lon_col.data,
            'animal_id': form.animal_id_col.data,
        }
        dataset.column_mapping = column_mapping

        try:
            df = ingest_csv(filepath, column_mapping)

            # Bulk insert data points
            points = [
                DataPoint(
                    dataset_id=dataset.id,
                    animal_id=row['animal_id'],
                    timestamp=row['timestamp'],
                    lat=row['lat'],
                    lon=row['lon'],
                    sensor_data=row.get('sensor_data'),
                    speed=row.get('speed'),
                    step_length=row.get('step_length'),
                    turning_angle=row.get('turning_angle'),
                )
                for _, row in df.iterrows()
            ]
            db.session.bulk_save_objects(points)

            # Denormalized summary for fast listing and map rendering.
            dataset.num_animals = df['animal_id'].nunique()
            dataset.num_fixes = len(df)
            dataset.time_start = df['timestamp'].min()
            dataset.time_end = df['timestamp'].max()
            dataset.bbox_min_lat = df['lat'].min()
            dataset.bbox_max_lat = df['lat'].max()
            dataset.bbox_min_lon = df['lon'].min()
            dataset.bbox_max_lon = df['lon'].max()
            dataset.upload_status = 'ready'
            db.session.commit()

            # Cleanup session
            session.pop('pending_dataset_id', None)
            session.pop('pending_filepath', None)

            flash(f'Dataset imported: {dataset.num_fixes} fixes from {dataset.num_animals} animals.', 'success')
            return redirect(url_for('datasets.view', dataset_id=dataset.id))

        except Exception as e:
            # A failed flush/commit leaves the session dirty; roll back
            # first, otherwise the recovery commit below can itself raise.
            db.session.rollback()
            dataset.upload_status = 'error'
            db.session.commit()
            flash(f'Import error: {e}', 'error')
            return redirect(url_for('datasets.upload'))

    return render_template('map_columns.html', form=form, preview=preview)


@bp.route('/datasets/<int:dataset_id>')
@login_required
def view(dataset_id):
    """Render a single dataset page with its distinct animal ids."""
    dataset = Dataset.query.get_or_404(dataset_id)
    if dataset.user_id != current_user.id:
        flash('Access denied.', 'error')
        return redirect(url_for('datasets.list_datasets'))

    # Distinct animal ids come back as one-element row tuples; unpack them.
    rows = (
        db.session.query(DataPoint.animal_id)
        .filter_by(dataset_id=dataset.id)
        .distinct()
        .all()
    )
    animal_ids = [animal_id for (animal_id,) in rows]

    return render_template('view.html', dataset=dataset, animal_ids=animal_ids)


@bp.route('/datasets/<int:dataset_id>/delete', methods=['POST'])
@login_required
def delete(dataset_id):
    """Delete a dataset owned by the current user (POST only)."""
    dataset = Dataset.query.get_or_404(dataset_id)
    if dataset.user_id == current_user.id:
        db.session.delete(dataset)
        db.session.commit()
        flash('Dataset deleted.', 'success')
    else:
        flash('Access denied.', 'error')
    return redirect(url_for('datasets.list_datasets'))
