import { useTheme } from '@mui/material'
import chroma from 'chroma-js'
import ReactECharts, { EChartsOption } from 'echarts-for-react'
import { memo, useEffect, useRef } from 'react'
import ChartData from '../../model/chart/ChartData'
import ChartType from '../../model/chart/ChartType'
import GroupedChartData, { GroupedDataPoint } from '../../model/chart/GroupedChartData'
import SankeyChartData from '../../model/chart/SankeyChartData'
import StandardChartData, { StandardDataPoint } from '../../model/chart/StandardChartData'
import ColorUtil from '../../util/ColorUtil'
import StringUtil from '../../util/StringUtil'
import { useDynamicGridItemContext } from '../layout/DynamicGrid'

/**
 * Optional overrides for the spacing around the chart's grid.
 * Every side that is left undefined falls back to the default of the rendered chart type.
 */
export type Padding = {
  /** Spacing on the left side of the grid. */
  left?: number
  /** Spacing on the right side of the grid. */
  right?: number
  /** Spacing above the grid. */
  top?: number
  /** Spacing below the grid. */
  bottom?: number
}

/** Value passed through as the `type` of the key axis; the component defaults to 'category'. */
export type KeyType = 'category' | 'time'

/** Props accepted by the {@link EChart} component. */
export interface EChartProps {
  /** Data to plot; the component distinguishes standard, grouped and sankey chart data by instance. */
  data: ChartData
  /** Chart type to render for the given data. */
  type: ChartType
  /** Formats a key for key-axis labels and the default tooltip. Defaults to the identity function. */
  keyFormatter?: (value: string) => string
  /** Formats a value for value-axis labels and the default tooltip. Defaults to `StringUtil.numberFormat`. */
  valueFormatter?: (value: number) => string
  /** Options merged into the key axis. */
  keyAxis?: {
    /** Whether the key axis is shown; when omitted the chart type's default applies. */
    show?: boolean
    /** Axis name. */
    name?: string
    /** Where the axis name is placed. */
    nameLocation?: string
    /** Gap between the axis name and the axis. */
    nameGap?: number
    /** Axis type; defaults to 'category'. */
    type?: KeyType
  }
  /** Options merged into the value axis; `nameLocation` defaults to 'center' and `nameGap` to 40 when a name is set. */
  valueAxis?: { show?: boolean; name?: string; nameLocation?: string; nameGap?: number }
  /** Currently selected keys; used when coloring data points (non-selected keys may be drawn in gray). */
  selectedValues?: string[]
  /** Called with the name of a clicked chart element; every ' > ' in the name is replaced by '->' first. */
  onSelectValue?: (value: string) => void
  /** Fixed color used in place of the generated color scheme. */
  color?: string
  /** Tooltip options merged over the component's default tooltip options. */
  tooltip?: EChartsOption['tooltip']
  /** When truthy, the value axis gets `min: 'dataMin'` and `max: 'dataMax'`. */
  constrainedValues?: boolean
  /** Overrides for the default grid padding of the chart type. */
  padding?: Padding
  /** Length passed to `StringUtil.shorten` for key-axis labels. Defaults to 12. */
  maxKeyLength?: number
  /** Value of the chart's `animation` option. Defaults to true. */
  animation?: boolean
  /**
   * Markings positioned on the x axis. An entry with an `end` that differs from `start` is
   * added as a mark area from `start` to `end`, any other entry as a mark line at `start`.
   */
  markings?: { name: string; start: any; end?: any; tooltip?: string }[]
  /** Legend names that start out selected in multi-line charts. */
  initiallySelectedLegends?: string[]
  /** Legend names that are selected again right away when the user deselects them. */
  persistentLegends?: string[]
  /** Called with the legend name and its new selected state, unless a persistent legend was re-selected. */
  onLegendClick?: (name: string, selected: boolean) => void
}

// ReactECharts wrapped in React.memo, which lets React skip re-rendering when the props are shallowly equal.
const MemoizedReactECharts = memo(ReactECharts)

export default function EChart({
  data,
  type,
  keyFormatter = (value: string) => value,
  valueFormatter = (value: number) => StringUtil.numberFormat(value),
  keyAxis,
  valueAxis,
  selectedValues,
  onSelectValue,
  color,
  tooltip,
  constrainedValues,
  padding,
  maxKeyLength = 12,
  animation = true,
  markings,
  initiallySelectedLegends,
  persistentLegends,
  onLegendClick,
}: EChartProps) {
  const theme = useTheme()
  // Handle on the rendered chart component, used to reach the underlying ECharts instance
  const echartRef = useRef<ReactECharts | null>(null)

  // resizeKey is updated whenever the chart's parent <Card> was resized, which is when the chart must follow.
  const { key, currentKey, resizeKey } = useDynamicGridItemContext()

  /**
   * Keeps the chart in sync with its container: when this element is resized
   * with ReactGridLayout, the chart is resized as well.
   */
  useEffect(() => {
    // Resize events triggered by another element are ignored
    const resizedThisItem = currentKey === key
    if (resizedThisItem) {
      echartRef.current?.getEchartsInstance().resize()
    }
  }, [key, currentKey, resizeKey])

  /**
   * Forwards the name of a clicked chart element to `onSelectValue`, if that callback is set.
   * Names containing ' > ' have every ' > ' replaced by '->' before they are forwarded.
   */
  function handleKeyClick(key: string) {
    const forwardedKey = key.includes(' > ') ? key.split(' > ').join('->') : key

    onSelectValue?.(forwardedKey)
  }

  /**
   * Reacts to a legend toggle.
   * A persistent legend that was just deselected is selected again and the click is not
   * reported; every other toggle is passed on to `onLegendClick`.
   */
  function handleLegendClick(name: string, selected: boolean) {
    const isPersistent = persistentLegends?.includes(name) ?? false
    const wasDeselected = !selected

    if (isPersistent && wasDeselected) {
      // Undo the deselection of a persistent legend
      echartRef.current?.getEchartsInstance().dispatchAction({ type: 'legendSelect', name })
      return
    }

    onLegendClick?.(name, selected)
  }

  /**
   * Picks the color for one data key.
   *
   * With exactly two items the first one gets the theme's primary color and the second a
   * brightened variant of it. Otherwise the `color` prop (or, without it, the generated
   * color scheme) is used, except that keys outside a non-empty selection are gray.
   *
   * @param key Key of the item to color
   * @param index Position of the item among all items
   * @param total Number of items that are colored together
   */
  function getColorForKey(key: string, index: number, total: number): string {
    if (total === 2) {
      const primaryHex = chroma(theme.palette.primary.main).hex()
      if (index === 0) return primaryHex

      return chroma(primaryHex).brighten(2).hex()
    }

    const scheme = ColorUtil.generateColorScheme(theme.palette.primary.main, total)
    const hasSelection = (selectedValues?.length ?? 0) > 0
    const isSelected = selectedValues?.includes(key) ?? false

    // Only keys that lose out against an active selection are grayed out
    if (hasSelection && !isSelected) return 'gray'

    return color ?? scheme[index % scheme.length]
  }

  /**
   * Combines the `padding` prop with the defaults of a chart type.
   * A side given in the prop wins; a side that is missing there uses `defaultPadding`.
   */
  function getPadding(defaultPadding: Padding): Padding {
    const resolveSide = (side: keyof Padding) => padding?.[side] ?? defaultPadding[side]

    return {
      left: resolveSide('left'),
      right: resolveSide('right'),
      top: resolveSide('top'),
      bottom: resolveSide('bottom'),
    }
  }

  function getOptions(): EChartsOption {
    // Default tooltip text: "<formatted key>: <formatted value>"
    const defaultTooltipFormatter = (params: any) => {
      // For array values the second entry is the one that gets formatted
      const rawValue = Array.isArray(params.value) ? params.value[1] : params.value
      return `${keyFormatter(params.name)}: ${valueFormatter(rawValue)}`
    }

    // Options used by every chart type; the `tooltip` prop overrides the tooltip defaults
    const globalSharedOptions = {
      color: theme.palette.primary.main,
      animation,
      tooltip: {
        trigger: 'item',
        confine: true,
        formatter: defaultTooltipFormatter,
        ...tooltip,
      },
    }

    // Series defaults reused by the charts built from standard data
    const standardSharedOptions = {
      series: {
        emphasis: { itemStyle: { color: theme.palette.primary.dark } },
        barCategoryGap: '5%',
      },
    }

    // Helpers
    /**
     * Builds the options of the key axis.
     * The `keyAxis` prop is merged in first; `show` and `type` from the prop win over the
     * defaults, while `data` and `axisLabel` are always set here.
     *
     * @param keys Keys listed on the axis
     * @param show Visibility used when the `keyAxis` prop does not specify one
     */
    function getKeyAxisOptions(keys: string[], show = true) {
      const formatLabel = (value: string) => StringUtil.shorten(keyFormatter(value), maxKeyLength)
      const { show: showOverride, type: typeOverride } = keyAxis ?? {}

      return {
        ...keyAxis,
        show: showOverride ?? show,
        type: typeOverride ?? 'category',
        data: keys,
        axisLabel: { formatter: formatLabel, rotate: 45 },
      }
    }

    /**
     * Builds the options of the value axis.
     * The `valueAxis` prop is merged in first. The axis name is centered unless the prop
     * says otherwise, and the name gap is only applied (default 40) when a name is set.
     * With `constrainedValues` the axis range is limited to the data range.
     *
     * @param show Visibility used when the `valueAxis` prop does not specify one
     */
    function getValueAxisOptions(show = true) {
      const commonOptions = {
        ...valueAxis,
        // Visibility follows the value axis' own `show` flag, independent of the key axis
        show: valueAxis?.show ?? show,
        type: 'value',
        axisLabel: {
          formatter: (value: number) => valueFormatter(value),
        },
        name: valueAxis?.name,
        nameLocation: valueAxis?.nameLocation ?? 'center',
        nameGap: valueAxis?.name ? (valueAxis.nameGap ?? 40) : 0,
      }

      return constrainedValues
        ? { ...commonOptions, min: 'dataMin', max: 'dataMax' }
        : commonOptions
    }

    /**
     * Builds a single bar or line series from standard data points.
     * Every point gets its own color; markings (if any) are attached to the series.
     *
     * @param points Data points to plot
     * @param type Series type
     */
    function getStandardSeries(points: StandardDataPoint[], type: 'bar' | 'line') {
      const total = points.length
      const seriesData = points.map(({ key, value }, position) => ({
        name: key,
        value,
        itemStyle: { color: getColorForKey(key, position, total) },
      }))

      return {
        ...standardSharedOptions.series,
        ...getMarkingOptions(),
        data: seriesData,
        type,
        barCategoryGap: '5%',
      }
    }

    /**
     * Builds one series per group from grouped data points.
     * Each series contains the points of its group as `[key, value]` pairs and is drawn
     * in the color picked for the group.
     *
     * @param groups Names of the groups, one series is created for each
     * @param points All data points; they are assigned to series by their `group`
     * @param type Series type
     */
    function getGroupedSeries(groups: string[], points: GroupedDataPoint[], type: 'bar' | 'line') {
      return groups.map((group, index) => {
        const color = getColorForKey(group, index, groups.length)
        // Only apply markings if it's the first group
        const markingOptions = index === 0 ? getMarkingOptions() : {}

        return {
          ...markingOptions,
          name: group,
          type,
          data: points
            .filter((point) => point.group === group)
            .map((point) => ({
              name: point.key,
              value: [point.key, point.value],
            })),
          color,
          itemStyle: { color },
          lineStyle: { color },
        }
      })
    }

    /**
     * Translates the `markings` prop into series options.
     * A marking with a truthy `end` that differs from its `start` becomes a mark area from
     * `start` to `end`; every other marking becomes a mark line at `start`. The index of the
     * marking is used as item name so that the tooltip can look up the marking again.
     *
     * @returns An object with `markLine` and/or `markArea`, or an empty object without markings
     */
    function getMarkingOptions() {
      if (!markings) return {}

      const lineItems: unknown[] = []
      const areaItems: unknown[] = []

      markings.forEach((marking, position) => {
        const labelAt = (where: string) => ({ position: where, formatter: () => marking.name })
        const spansRange = Boolean(marking.end) && marking.end !== marking.start

        if (!spansRange) {
          lineItems.push({ name: position, xAxis: marking.start, label: labelAt('middle') })
          return
        }

        areaItems.push([
          { name: position, xAxis: marking.start, label: labelAt('inside') },
          { xAxis: marking.end },
        ])
      })

      // Used by both mark lines and mark areas
      const sharedMarkingOptions = {
        animation: false,
        tooltip: {
          // The item name is the index of the marking it was created from
          formatter: (params: any) => markings[params.name].tooltip,
        },
      }

      const result: any = {}
      if (lineItems.length > 0) {
        result.markLine = { ...sharedMarkingOptions, data: lineItems, symbol: 'none' }
      }
      if (areaItems.length > 0) {
        result.markArea = {
          ...sharedMarkingOptions,
          data: areaItems,
          emphasis: { label: { position: 'inside' } },
        }
      }

      return result
    }

    // Chart-specific options
    // Charts built from standard (key/value) data
    if (data instanceof StandardChartData) {
      const keys = data.getKeys()
      // Per-point series entries for the chart types that do not go through getStandardSeries
      const valueData = data.points.map((point, index) => ({
        value: point.value,
        name: point.key,
        itemStyle: { color: getColorForKey(point.key, index, keys.length) },
      }))

      switch (type) {
        case ChartType.bar:
        // A sankey chart requested with standard data is drawn as a bar chart
        case ChartType.sankey:
          return {
            ...globalSharedOptions,
            ...standardSharedOptions,
            xAxis: getKeyAxisOptions(keys),
            yAxis: getValueAxisOptions(),
            series: [getStandardSeries(data.points, 'bar')],
            grid: getPadding({ left: 50, right: 15, top: 15, bottom: 60 }),
          }
        case ChartType.horizontalBar:
          // Horizontal bars: keys go on the vertical axis, values on the horizontal axis
          return {
            ...globalSharedOptions,
            ...standardSharedOptions,
            xAxis: getValueAxisOptions(),
            yAxis: getKeyAxisOptions(keys),
            series: [getStandardSeries(data.points, 'bar')],
            grid: getPadding({ left: 70, right: 15, top: 5, bottom: 60 }),
          }
        case ChartType.donut:
        case ChartType.pie:
          return {
            ...globalSharedOptions,
            ...standardSharedOptions,
            // Both axes are hidden by default for pie and donut charts
            xAxis: getKeyAxisOptions(keys, false),
            yAxis: getValueAxisOptions(false),
            series: [
              {
                ...standardSharedOptions.series,
                data: valueData,
                type: 'pie',
                label: {
                  show: true,
                  position: 'outside',
                  formatter: '{b}',
                },
                labelLine: {
                  show: true,
                  length: 10,
                  length2: 10,
                },
                // A donut is a pie with an inner radius
                ...(type === ChartType.donut ? { radius: ['50%', '70%'] } : {}),
              },
            ],
          }
        case ChartType.line:
          return {
            ...globalSharedOptions,
            ...standardSharedOptions,
            xAxis: getKeyAxisOptions(keys),
            yAxis: getValueAxisOptions(),
            series: [getStandardSeries(data.points, 'line')],
            grid: getPadding({ left: 35, right: 10, top: 5, bottom: 40 }),
          }
        case ChartType.treemap: {
          const treemapData = data.points.map((point, index) => ({
            value: point.value,
            name: point.key,
            itemStyle: {
              color: getColorForKey(point.key, index, data.points.length),
              borderColor: 'white',
              borderWidth: 2,
            },
          }))
          return {
            ...globalSharedOptions,
            ...standardSharedOptions,
            series: [
              {
                ...standardSharedOptions.series,
                type: 'treemap',
                roam: false,
                breadcrumb: { show: false },
                data: treemapData,
              },
            ],
          }
        }
        case ChartType.levels:
          return {
            ...globalSharedOptions,
            // Keys on the vertical axis, values on the horizontal axis, both inverted
            yAxis: { ...getKeyAxisOptions(keys), inverse: true },
            xAxis: { ...getValueAxisOptions(), inverse: true },
            series: [
              {
                ...standardSharedOptions.series,
                data: valueData,
                type: 'bar',
                label: {
                  show: true,
                  position: 'right',
                  // Label every bar with its key
                  formatter: function (params: { dataIndex: number }) {
                    return keys[params.dataIndex]
                  },
                  fontSize: 25,
                  fontWeight: 'bold',
                },
              },
            ],
            grid: getPadding({ left: 30, right: 50, top: 5, bottom: 40 }),
          }

        case ChartType.levelsFull:
          return {
            ...globalSharedOptions,
            ...standardSharedOptions,
            // Line without any visible axes
            xAxis: { ...getKeyAxisOptions(keys), show: false },
            yAxis: { ...getValueAxisOptions(), show: false },
            series: [{ ...standardSharedOptions.series, data: valueData, type: 'line' }],
            grid: getPadding({ left: 10, right: 10, top: 10, bottom: 10 }),
          }
        default:
          // Chart types that are not available for standard data produce no options here
          break
      }
    }

    // Charts built from grouped data: one series per group
    if (data instanceof GroupedChartData) {
      const keys = Array.from(new Set(data.points.map((point) => point.key)))
      const groups = Array.from(new Set(data.points.map((point) => point.group)))

      const groupedSharedOptions = {
        xAxis: getKeyAxisOptions(keys),
        yAxis: getValueAxisOptions(),
      }

      // Legend entries are drawn as circles in the color of their series
      const toLegendData = (seriesList: { name: string; color: string }[]) =>
        seriesList.map(({ name, color: seriesColor }) => ({
          name,
          icon: 'circle',
          itemStyle: { color: seriesColor },
        }))

      // Set selected to true and groups that are not present to false
      const selectedGroup: { [key: string]: boolean } = {}
      for (const legend of initiallySelectedLegends ?? []) {
        selectedGroup[legend] = true
      }
      for (const group of groups) {
        if (!selectedGroup[group]) {
          selectedGroup[group] = false
        }
      }
      // Without an explicit initial selection the legend state is not set at all, so that
      // the series are not all deselected (and thereby hidden) from the start
      const legendSelection = initiallySelectedLegends ? { selected: selectedGroup } : {}

      switch (type) {
        case ChartType.groupedBar: {
          // Built once and used for both the series and the legend
          const groupedBarSeries = getGroupedSeries(groups, data.points, 'bar')
          return {
            ...globalSharedOptions,
            ...groupedSharedOptions,
            series: groupedBarSeries,
            grid: getPadding({ left: 50, right: 15, top: 30, bottom: 70 }),
            legend: {
              show: true,
              data: toLegendData(groupedBarSeries),
              top: 0,
            },
          }
        }
        case ChartType.multiLine: {
          const multiLineSeries = getGroupedSeries(groups, data.points, 'line')
          return {
            ...globalSharedOptions,
            ...groupedSharedOptions,
            series: multiLineSeries,
            grid: getPadding({ left: 35, right: 10, top: 30, bottom: 70 }),
            legend: {
              show: true,
              top: 0,
              data: toLegendData(multiLineSeries),
              ...legendSelection,
            },
            // Multi-line charts always animate; this overrides the `animation` prop
            animation: true,
            animationDuration: 2000,
            animationEasing: 'linear',
          }
        }
      }
    }

    // Sankey diagram built from source/target/value data points
    if (data instanceof SankeyChartData) {
      // One node per distinct name that occurs as a source or as a target
      const uniqueNames = new Set(data.points.flatMap((point) => [point.source, point.target]))
      const nodes = Array.from(uniqueNames, (name) => ({ name }))

      // One translucent link per data point, with its emphasis state disabled
      const links = data.points.map((point, position) => ({
        source: point.source,
        target: point.target,
        value: point.value,
        lineStyle: {
          color: getColorForKey(point.source, position, nodes.length),
          opacity: 0.2,
        },
        emphasis: {
          disabled: true,
        },
      }))

      // Nodes that are the target of at least one link get their emphasis state disabled
      const nodesWithEmphasis = nodes.map((node) => ({
        ...node,
        emphasis: {
          disabled: links.some((link) => link.target === node.name),
        },
      }))

      // Links show the name of their source, nodes show their own name
      const sankeyTooltipFormatter = (params: any) =>
        params.dataType === 'edge' ? params.data.source : params.name

      return {
        ...globalSharedOptions,
        // Replaces the shared tooltip options entirely
        tooltip: {
          trigger: 'item',
          formatter: sankeyTooltipFormatter,
        },
        series: [
          {
            type: 'sankey',
            layout: 'none',
            roam: false,
            draggable: false,
            // NOTE(review): both `data` and `nodes` are supplied; the ECharts documentation
            // describes `nodes` as equal to `data` - confirm which of the two takes effect
            data: nodes,
            nodeGap: 12,
            links,
            nodes: nodesWithEmphasis,
            itemStyle: {
              normal: {
                borderWidth: 1,
                borderColor: '#aaa',
              },
            },
            lineStyle: {
              normal: {
                curveness: 0.7,
              },
            },
            label: {
              normal: {
                show: true,
                textStyle: {
                  color: '#333', // Text color
                  borderColor: 'transparent', // Remove outline
                },
              },
            },
          },
        ],
      }
    }
  }

  // getOptions() produces nothing for unsupported combinations of data and chart type;
  // fall back to an empty option object so that the chart never receives `undefined`.
  const options = getOptions() ?? {}
  return (
    <div style={{ width: '100%', height: '100%' }}>
      <MemoizedReactECharts
        ref={echartRef}
        option={options}
        style={{ width: '100%', height: '100%' }}
        onEvents={{
          click: (params: { name: any }) => handleKeyClick(params.name),
          legendselectchanged: (params: { name: string; selected: { [key: string]: boolean } }) => {
            handleLegendClick(params.name, params.selected[params.name])
          },
        }}
      />
    </div>
  )
}
