

import { CharRaster, IRaster, Raster } from "../../drawing/raster";
import { IRasterPoint, RasterPoint } from "../../geometry/raster-point";
import { RasterBox } from "../../geometry/raster-box";
import Heap from "heap";

/** Side of a raster cell. Each cell is represented in the search graph by one node per side. */
export enum Side {
    Left = 0,
    Right = 1,
    Top = 2,
    Bottom = 3,
}

/** All four sides, in declaration order. */
const allSides: readonly Side[] = [Side.Left, Side.Right, Side.Top, Side.Bottom]

/**
 * One side of one raster cell, together with the mutable state a path search
 * keeps for it. Fields are assigned after construction by the path finder.
 */
class GraphNode {
    /** Column of the cell this node belongs to. */
    x!: number
    /** Row of the cell this node belongs to. */
    y!: number
    /** Which side of the cell this node represents. */
    side!: Side
    /** Nodes reachable from this one in a single step. */
    neighbours!: GraphNode[]
    /** True when the cell is occupied and must not be entered. */
    isBlocked!: boolean
    /** True once the node is settled (or seeded) and may no longer be relaxed. */
    isVisited!: boolean

    /** Accumulated distance cost from the search origin. */
    distanceFromStart!: number
    /** Accumulated turn cost from the search origin (tie-breaker after distance). */
    turnCount!: number
    /** Previous node on the best known path; unset for origins and unreached nodes. */
    predecessor!: GraphNode
}

/** Heap/sort comparator: ascending by distance from start, then by turn count. */
function compareNodes(a: GraphNode, b: GraphNode): number {
    const { distanceFromStart: bDist, turnCount: bTurns } = b
    return compareNodeDist(a, bDist, bTurns)
}

/**
 * Compares node `a` against a (distance, turn count) pair, lexicographically.
 * Returns a negative number when `a` is cheaper, positive when it is more
 * expensive, and 0 when both components are equal.
 */
function compareNodeDist(a: GraphNode, bDist: number, bTurnCount: number): number {
    if (a.distanceFromStart === bDist) {
        return a.turnCount === bTurnCount ? 0 : a.turnCount - bTurnCount
    }
    return a.distanceFromStart - bDist
}

/** Builds the map key for the node on `side` of cell (x, y), formatted as "x/y/side". */
function makeNodeKey(x: number, y: number, side: Side): string {
    return `${x}/${y}/${side}`
}

/** Looks up the node for `side` of cell (x, y); undefined when no such node has been created. */
function getNode(nodes: Map<string, GraphNode>, x: number, y: number, side: Side) {
    const key = makeNodeKey(x, y, side)
    return nodes.get(key)
}

/**
 * Finds a path of raster cells from `startPos` through `viaPos` to `endPos`, avoiding
 * cells that are not visibly empty in `raster`.
 *
 * Every cell is modelled as four graph nodes, one per side. Paths are ranked by distance
 * first and turn count second (see getDistBetweenNeighbours / getTurnCostBetweeNeighbours).
 * The search runs Dijkstra twice: from the start cell to every node, then from viaPos
 * onwards while keeping the first phase's paths into viaPos intact.
 */
export class ShortestPathFinder {

    private nodeMap = new Map<string, GraphNode>()
    private allowedStartDirections = new Set<Side>(allSides)
    private allowedEndDirections = new Set<Side>(allSides)

    /**
     * @param raster   Content to route around; cells that are not visibly empty are blocked.
     * @param startPos Cell the path starts in.
     * @param viaPos   Cell the path is routed through.
     * @param endPos   Cell the path ends in.
     */
    constructor(private raster: IRaster<string>, private startPos: IRasterPoint, private viaPos: IRasterPoint, private endPos: IRasterPoint) {}

    /** Restricts which side nodes of the start cell the search may begin from (default: all). */
    setAllowedStartDirections(directions: readonly Side[]) {
        this.allowedStartDirections = new Set<Side>(directions)
    }

    /** Restricts which side nodes of the end cell may terminate the path (default: all). */
    setAllowedEndDirections(directions: readonly Side[]) {
        this.allowedEndDirections = new Set<Side>(directions)
    }

    /**
     * (Re)builds the node graph over the bounding box of start, via, end and the raster
     * content, expanded by two cells on every side.
     */
    private buildGraph(raster: IRaster<string>) {
        let bounds = new RasterBox(this.startPos, this.endPos)
            .getUnitedWith(new RasterBox(this.viaPos, this.viaPos))
        if (!Raster.isEmpty(raster)) {
            bounds = bounds.getUnitedWith(Raster.getBounds(raster))
        }
        // Presumably the margin lets paths route around content touching the raster edge.
        bounds = bounds.getExpandedBy(2, 2, 2, 2)

        // Drop nodes from a previous call so stale cells outside the new bounds are not
        // re-initialised (and re-queried against the raster) by initNodes.
        this.nodeMap.clear()
        this.createNodes(bounds, this.nodeMap)
        this.connectNodes(bounds, this.nodeMap)
    }

    /**
     * Links every side node to the facing side of the adjacent cell (when inside `bounds`)
     * and to the other three sides of its own cell.
     */
    private connectNodes(bounds: RasterBox, nodes: Map<string, GraphNode>) {
        for (let x = bounds.minX(); x <= bounds.maxX(); ++x) {
            for (let y = bounds.minY(); y <= bounds.maxY(); ++y) {
                const left = getNode(nodes, x, y, Side.Left)
                const right = getNode(nodes, x, y, Side.Right)
                const top = getNode(nodes, x, y, Side.Top)
                const bottom = getNode(nodes, x, y, Side.Bottom)

                // Neighbour order is kept as-is: Dijkstra only updates on strict improvement,
                // so the order can decide which of several equal-cost predecessors is kept.
                left.neighbours = x > bounds.minX()
                    ? [getNode(nodes, x - 1, y, Side.Right), right, top, bottom]
                    : [right, top, bottom]
                right.neighbours = x < bounds.maxX()
                    ? [getNode(nodes, x + 1, y, Side.Left), left, top, bottom]
                    : [left, top, bottom]
                top.neighbours = y > bounds.minY()
                    ? [getNode(nodes, x, y - 1, Side.Bottom), right, left, bottom]
                    : [right, left, bottom]
                bottom.neighbours = y < bounds.maxY()
                    ? [getNode(nodes, x, y + 1, Side.Top), left, top, right]
                    : [left, top, right]
            }
        }
    }

    /** Creates one node per side for every cell inside `bounds` (inclusive). */
    private createNodes(bounds: RasterBox, nodes: Map<string, GraphNode>) {
        for (let x = bounds.minX(); x <= bounds.maxX(); ++x) {
            for (let y = bounds.minY(); y <= bounds.maxY(); ++y) {
                for (const side of allSides) {
                    const node = new GraphNode()
                    node.x = x
                    node.y = y
                    node.side = side
                    nodes.set(makeNodeKey(x, y, side), node)
                }
            }
        }
    }

    /**
     * Resets all per-search state before phase 1.
     *
     * All four sides of the start cell get cost (0, 0), including sides that are not in
     * allowedStartDirections. Nothing can strictly improve on (0, 0), so those disallowed
     * sides are never entered. This keeps phase 1 from leaving the start cell in a
     * disallowed direction.
     */
    private initNodes() {
        for (const node of this.nodeMap.values()) {
            const isStartCell = RasterPoint.equals(this.startPos, { x: node.x, y: node.y })
            node.distanceFromStart = isStartCell ? 0 : Number.POSITIVE_INFINITY
            node.turnCount = isStartCell ? 0 : Number.POSITIVE_INFINITY
            node.predecessor = undefined
            node.isBlocked = !CharRaster.isVisiblyEmptyAt(this.raster, node.x, node.y)
            node.isVisited = false
        }
    }

    /**
     * Runs the two-phase search.
     *
     * Returns the cells from startPos to endPos in order, with consecutive visits to the
     * same cell collapsed. Returns [] when the chosen end node's chain does not lead back
     * to startPos, for example when it is unreachable.
     */
    findShortestPath(): IRasterPoint[] {
        this.buildGraph(this.raster)
        this.initNodes()

        // Phase 1: from the allowed start sides to every reachable node.
        const startSeeds = [...this.allowedStartDirections]
            .map(d => getNode(this.nodeMap, this.startPos.x, this.startPos.y, d))
        this.runDijkstra(this.createSeededHeap(startSeeds))

        // Phase 2: continue from viaPos, keeping the phase-1 paths into viaPos.
        this.retainPathsToVia()
        // Seed only via sides that phase 1 actually reached. When viaPos equals startPos,
        // the disallowed start sides have cost 0 from initNodes without having been reached.
        // Seeding them would let the path leave the start cell in a disallowed direction.
        // Unreached via sides (infinite cost) would contribute nothing anyway.
        const startSeedSet = new Set(startSeeds)
        const viaSeeds = allSides
            .map(d => getNode(this.nodeMap, this.viaPos.x, this.viaPos.y, d))
            .filter(node => node.predecessor !== undefined || startSeedSet.has(node))
        this.runDijkstra(this.createSeededHeap(viaSeeds))

        return this.collectCellPath()
    }

    /**
     * Prepares node state for phase 2.
     *
     * Nodes on the predecessor chains of the four viaPos sides keep their phase-1 cost and
     * predecessor; every other node is reset to unreached. All nodes are marked unvisited.
     * The retained chains are not blocked explicitly. Their phase-1 costs are already
     * optimal, so phase 2 can never strictly improve (and thus re-route) them.
     */
    private retainPathsToVia() {
        const retained = new Set<GraphNode>()
        for (const side of allSides) {
            forEachNodeInChain(getNode(this.nodeMap, this.viaPos.x, this.viaPos.y, side), node => retained.add(node))
        }
        for (const node of this.nodeMap.values()) {
            if (!retained.has(node)) {
                // Same "unreached" state as initNodes.
                node.distanceFromStart = Number.POSITIVE_INFINITY
                node.turnCount = Number.POSITIVE_INFINITY
                node.predecessor = undefined
            }
            node.isVisited = false
        }
    }

    /**
     * Creates a heap holding `seeds` and marks each seed visited. runDijkstra never relaxes
     * visited nodes, so the seeds' costs cannot be overwritten.
     */
    private createSeededHeap(seeds: GraphNode[]): Heap<GraphNode> {
        const heap = new Heap<GraphNode>(compareNodes)
        for (const node of seeds) {
            heap.push(node)
            node.isVisited = true
        }
        return heap
    }

    /**
     * Picks the cheapest allowed end-side node and follows its predecessor chain back.
     * Converts the chain into a list of distinct consecutive cells, or [] when the chain
     * does not start at startPos.
     */
    private collectCellPath(): IRasterPoint[] {
        // NOTE(review): if the phase-1 path into viaPos runs through the end cell, that end
        // side keeps its (cheaper) phase-1 cost and can win here. The result would then be a
        // path that skips viaPos. Confirm whether callers rely on this.
        const endNode = [...this.allowedEndDirections]
            .map(d => getNode(this.nodeMap, this.endPos.x, this.endPos.y, d))
            .sort(compareNodes)[0]

        const nodePath: GraphNode[] = []
        forEachNodeInChain(endNode, node => nodePath.push(node))
        nodePath.reverse()

        if (nodePath.length === 0 || !RasterPoint.equals(nodePath[0], this.startPos)) {
            return []
        }

        const cellPath: IRasterPoint[] = []
        let currentPos = RasterPoint.xy(Number.POSITIVE_INFINITY, Number.POSITIVE_INFINITY)
        for (const node of nodePath) {
            const nodePos = RasterPoint.xy(node.x, node.y)
            if (!RasterPoint.equals(currentPos, nodePos)) {
                currentPos = nodePos
                cellPath.push(nodePos)
            }
        }
        return cellPath
    }

    /**
     * Runs Dijkstra on lexicographic (distance, turns) costs, starting from the nodes
     * already in `heap`.
     *
     * Blocked and visited nodes are never relaxed. Improved nodes are pushed again rather
     * than re-keyed, so a node may be popped more than once. On a later pop its cost can no
     * longer change (it is visited), so it only re-relaxes its neighbours.
     */
    private runDijkstra(heap: Heap<GraphNode>) {
        while (heap.size() > 0) {
            const currentNode = heap.pop()
            currentNode.isVisited = true

            for (const nextNode of currentNode.neighbours) {
                if (nextNode.isBlocked || nextNode.isVisited) {
                    continue
                }
                const nextDist = currentNode.distanceFromStart + getDistBetweenNeighbours(currentNode, nextNode)
                const nextTurnCost = currentNode.turnCount + getTurnCostBetweeNeighbours(currentNode, nextNode)
                // Update only on a strict improvement; on ties the first predecessor found wins.
                if (compareNodeDist(nextNode, nextDist, nextTurnCost) > 0) {
                    nextNode.distanceFromStart = nextDist
                    nextNode.turnCount = nextTurnCost
                    nextNode.predecessor = currentNode
                    // NOTE(review): nextNode may already be in the heap, and its key was just
                    // mutated in place. The `heap` package offers updateItem() for re-keying.
                    // Confirm the duplicate-push approach keeps pop order correct.
                    heap.push(nextNode)
                }
            }
        }
    }

}

/**
 * Distance cost of stepping from `a` to `b`.
 * Returns 0 within the same cell, 2 when the row changes and 1 otherwise (column change).
 */
function getDistBetweenNeighbours(a: GraphNode, b: GraphNode) {
    const inSameCell = a.x === b.x && a.y === b.y
    if (inSameCell) {
        return 0
    }
    return a.y === b.y ? 1 : 2
}

/**
 * Turn cost of stepping from `a` to `b`.
 *
 * Moving into a different cell costs 0. Within one cell, passing straight through
 * (Left<->Right or Top<->Bottom) costs 0 and any other side change costs 1.
 */
function getTurnCostBetweeNeighbours(a: GraphNode, b: GraphNode) {
    if (a.x !== b.x || a.y !== b.y) {
        return 0
    }
    const passesStraightThrough =
        (a.side === Side.Left && b.side === Side.Right) ||
        (a.side === Side.Right && b.side === Side.Left) ||
        (a.side === Side.Top && b.side === Side.Bottom) ||
        (a.side === Side.Bottom && b.side === Side.Top)
    return passesStraightThrough ? 0 : 1
}

/**
 * Calls `function_` on `node` and then on every node reached by following `predecessor`
 * links, stopping at the first missing link.
 *
 * @param node      First node of the chain; a falsy value results in no calls.
 * @param function_ Callback invoked once per node, starting with `node` itself.
 */
function forEachNodeInChain(node: GraphNode | undefined, function_: (node: GraphNode) => void) {
    // The original callback type `(GraphNode) => void` declared a parameter *named*
    // GraphNode with an implicit `any` type; it is now typed properly.
    let current = node
    while (current) {
        function_(current)
        current = current.predecessor
    }
}