// src/components/PredictionTools.tsx

import React, { useEffect } from 'react'
import { Button, Card } from 'react-bootstrap'
import { PredictionMetadata, Prediction } from '../../types/prediction'
import { useAppDispatch, useAppSelector } from '../../store/hooks'
import { addPrediction, clearPredictions, selectSelectedPredictions } from '../../store/slices/predictionSlice'
import { selectSelectedModel } from '../../store/slices/modelSlice'
import { useQueryPredictionsMutation } from '../../services/predictionApi'
import { useGetTagsQuery } from '../../services/tagApi'
import { useCreateAnnotationMutation } from '../../services/annotationApi'
import { invalidateTags } from '../../services/contentApi'
import { CreateBBoxAnnotationReq, defaultProjectQueryRegex } from '../../types/requests'
import { AnnotationDataBoundingBox } from '../../types/annotation'

import { selectActiveTab, selectSelectedProject } from '../../store/slices/projectSlice'

/**
 * Toolbar for bulk-handling model predictions: select every loaded
 * prediction, clear the selection, or accept the selected predictions as
 * annotations for the current project.
 *
 * When the selected model's batch status starts with "COMPLETE"
 * (case-insensitive), predictions for that model are queried (page 0,
 * limit 100). The card is not rendered on the 'annotated' and
 * 'unannotated' tabs.
 */
const PredictionTools: React.FC = () => {
  const dispatch = useAppDispatch()
  const project = useAppSelector(selectSelectedProject)
  const selectedPredictions = useAppSelector(selectSelectedPredictions)
  const selectedModel = useAppSelector(selectSelectedModel)
  const [createAnnotation] = useCreateAnnotationMutation()
  const [queryPredictions, { data: predictions, isLoading: isPredictionsLoading }] = useQueryPredictionsMutation()

  const activeTab = useAppSelector(selectActiveTab)

  const { data: tags } = useGetTagsQuery({
    project_id: selectedModel?.projectid || '',
    dataset_id: project?.datasetid || '',
  })

  useEffect(() => {
    if (!selectedModel?.batch?.status.toUpperCase().startsWith('COMPLETE')) {
      return
    }
    const modelId = selectedModel.id
    // Build a fresh request object instead of mutating the shared
    // `defaultProjectQueryRegex`, which previously leaked this model's id
    // into every other user of the default. filters[0] is the 'modelid'
    // filter by default, so only its value is overridden.
    const queryRegex = {
      ...defaultProjectQueryRegex,
      filters: defaultProjectQueryRegex.filters.map((filter, index) =>
        index === 0 ? { ...filter, value: modelId } : filter,
      ),
      limit: 100,
      page: 0,
    }
    queryPredictions(queryRegex)
  }, [selectedModel, queryPredictions])

  /** Returns the id of the first tag whose name equals `className`, if any. */
  const findTagId = (className: PredictionMetadata['class_name']) =>
    tags?.tags?.find((tag) => tag.name === className)?.id

  /**
   * Builds the create-annotation request for one prediction.
   *
   * Each predicted class that matches a known tag contributes its tag id
   * (de-duplicated). When `includeBoundingBoxes` is true, every matched
   * class that carries `bounding_boxes` also contributes a box named by its
   * tag id; otherwise `metadata` is left undefined. Classes with no matching
   * tag are skipped.
   */
  const buildAnnotationRequest = (
    prediction: Prediction,
    includeBoundingBoxes: boolean,
  ): CreateBBoxAnnotationReq => {
    const tagIds = new Set<string>()
    const boundingBoxes: AnnotationDataBoundingBox[] = []

    for (const metadata of prediction.predictions) {
      const tagId = findTagId(metadata.class_name)
      if (!tagId) {
        continue
      }
      tagIds.add(tagId)

      const box = metadata.bounding_boxes
      if (includeBoundingBoxes && box) {
        boundingBoxes.push({
          name: tagId,
          xmax: box.xmax,
          xmin: box.xmin,
          ymax: box.ymax,
          ymin: box.ymin,
        })
      }
    }

    return {
      content_id: prediction.contentid,
      dataset_id: project?.datasetid || '',
      metadata: includeBoundingBoxes ? { bounding_boxes: boundingBoxes } : undefined,
      project_id: selectedModel?.projectid || '',
      tag_id: Array.from(tagIds),
    }
  }

  /** Adds every loaded prediction that is not already selected. */
  const handleSelectAll = () => {
    // Set lookup avoids an O(n*m) scan of the current selection per item;
    // ids are added as we go so duplicates in the response are dispatched once.
    const alreadySelected = new Set(selectedPredictions.map((p) => p.id))
    predictions?.predictions?.forEach((prediction) => {
      if (!alreadySelected.has(prediction.id)) {
        alreadySelected.add(prediction.id)
        dispatch(addPrediction(prediction))
      }
    })
  }

  const handleClearSelections = () => {
    dispatch(clearPredictions())
  }

  /**
   * Creates one annotation per selected prediction (classification or
   * bounding-box, depending on the project's annotation type), clears the
   * selection, waits for the create requests to settle, then invalidates the
   * 'Content' tag so content is refetched after the annotations exist.
   * Other annotation types create nothing; content is still invalidated.
   */
  const handleAcceptSelections = async () => {
    if (selectedPredictions.length === 0) {
      return
    }

    const annotationType = project?.annotation_type
    if (annotationType === 'classification' || annotationType === 'bounding_box') {
      const includeBoundingBoxes = annotationType === 'bounding_box'
      const requests = selectedPredictions.map((prediction: Prediction) =>
        createAnnotation(buildAnnotationRequest(prediction, includeBoundingBoxes)),
      )
      // The selection was captured above, so it can be cleared right away;
      // both annotation types now clear it (previously only classification did).
      dispatch(clearPredictions())
      // Wait for the writes so the refetch below does not race them.
      await Promise.all(requests)
    }

    await dispatch(invalidateTags([{ type: 'Content' }]))
  }

  return (
    <>
      {activeTab !== 'annotated' && activeTab !== 'unannotated' && (
        <Card className="prediction-tools mt-3 mb-3">
          <Card.Header>
            <Card.Title className="mb-0">Select Predictions</Card.Title>
          </Card.Header>
          <Card.Body>
            <div className="d-flex justify-content-between">
              <Button variant="primary" onClick={handleSelectAll} disabled={isPredictionsLoading}>
                All
              </Button>
              <Button variant="outline-secondary" onClick={handleClearSelections}>
                Clear
              </Button>
              <Button
                variant="success"
                onClick={() => void handleAcceptSelections()}
                disabled={selectedPredictions.length === 0}
              >
                Accept ({selectedPredictions.length})
              </Button>
            </div>
          </Card.Body>
        </Card>
      )}
    </>
  )
}

export default PredictionTools
