// === CS400 File Header Information ===
// Name: <HANNAH WANG>
// Email: <jwang2766@wisc.edu>
// Group and Team: <your group name: two letters, and team color>
// Group TA: <name of your group's ta>
// Lecturer: <Gary>
// Notes to Grader: <optional extra notes>

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.PriorityQueue;
import org.junit.jupiter.api.Test;

/**
 * This class extends the BaseGraph data structure with additional methods for computing the total
 * cost and list of node data along the shortest path connecting a provided starting to ending
 * nodes. This class makes use of Dijkstra's shortest path algorithm.
 */
/**
 * This class extends the BaseGraph data structure with additional methods for computing the total
 * cost and list of node data along the shortest path connecting a provided starting node to a
 * provided ending node. This class makes use of Dijkstra's shortest path algorithm, so edge weights
 * are expected to be non-negative; with negative weights the reported path may not be the
 * shortest one.
 *
 * <p>The JUnit tests for this class are located at the bottom of this file.
 */
public class DijkstraGraph<NodeType, EdgeType extends Number> extends BaseGraph<NodeType, EdgeType>
    implements GraphADT<NodeType, EdgeType> {

  /**
   * While searching for the shortest path between two nodes, a SearchNode contains data about one
   * specific path between the start node and another node in the graph. The final node in this path
   * is stored in its node field. The total cost of this path is stored in its cost field. And the
   * predecessor SearchNode within this path is referenced by the predecessor field (this field is
   * null within the SearchNode containing the starting node in its node field).
   *
   * SearchNodes are Comparable and are sorted by cost so that the lowest cost SearchNode has the
   * highest priority within a java.util.PriorityQueue.
   */
  protected class SearchNode implements Comparable<SearchNode> {
    public Node node;
    public double cost;
    public SearchNode predecessor;

    public SearchNode(Node node, double cost, SearchNode predecessor) {
      this.node = node;
      this.cost = cost;
      this.predecessor = predecessor;
    }

    /**
     * Orders SearchNodes by increasing path cost.
     *
     * @param other the SearchNode to compare against
     * @return a negative value, zero, or a positive value when this path is cheaper than, as costly
     *         as, or more costly than the other path
     */
    @Override
    public int compareTo(SearchNode other) {
      // Double.compare defines a total order (including NaN), which PriorityQueue relies on.
      return Double.compare(cost, other.cost);
    }
  }

  /**
   * Constructor that sets the map that the graph uses.
   */
  public DijkstraGraph() {
    super(new HashtableMap<>());
  }

  /**
   * This helper method creates a network of SearchNodes while computing the shortest path between
   * the provided start and end locations. The SearchNode that is returned by this method represents
   * the end of the shortest path that is found: its cost is the cost of that shortest path, and the
   * nodes linked together through predecessor references represent all of the nodes along that
   * shortest path (ordered from end to start).
   *
   * @param start the data item in the starting node for the path
   * @param end   the data item in the destination node for the path
   * @return SearchNode for the final end node within the shortest path
   * @throws NoSuchElementException when no path from start to end is found or when either start or
   *                                end data do not correspond to a graph node
   */
  protected SearchNode computeShortestPath(NodeType start, NodeType end) {
    if (!this.containsNode(start) || !this.containsNode(end)) {
      throw new NoSuchElementException(
          "Graph does not contain start node " + start + " and/or end node " + end);
    }
    Node startNode = nodes.get(start);
    Node endNode = nodes.get(end);

    // Candidate paths, cheapest first. A node may be queued several times with different costs;
    // only the first (cheapest) occurrence polled is used, later ones are skipped as stale.
    PriorityQueue<SearchNode> frontier = new PriorityQueue<>();
    // Nodes whose shortest path is already final; used as a set (only keys are ever checked).
    HashtableMap<Node, Node> visited = new HashtableMap<>();
    frontier.add(new SearchNode(startNode, 0.0, null));

    while (!frontier.isEmpty()) {
      SearchNode current = frontier.poll();
      if (visited.containsKey(current.node)) {
        continue; // stale, more expensive path to an already-finalized node
      }
      visited.put(current.node, current.node);
      if (current.node.equals(endNode)) {
        return current;
      }
      // Queue a path to every not-yet-finalized neighbor through the current node.
      for (Edge edge : current.node.edgesLeaving) {
        Node neighbor = edge.successor;
        if (!visited.containsKey(neighbor)) {
          frontier.add(new SearchNode(neighbor, current.cost + edge.data.doubleValue(), current));
        }
      }
    }
    // Every node reachable from start was finalized without reaching end.
    throw new NoSuchElementException("No path from " + start + " to " + end);
  }

  /**
   * Returns the list of data values from nodes along the shortest path from the node with the
   * provided start value through the node with the provided end value. This list of data values
   * starts with the start value, ends with the end value, and contains intermediary values in the
   * order they are encountered while traversing this shortest path. This method uses Dijkstra's
   * shortest path algorithm to find this solution.
   *
   * @param start the data item in the starting node for the path
   * @param end   the data item in the destination node for the path
   * @return list of data items from nodes along this shortest path
   * @throws NoSuchElementException when no path from start to end is found or when either start or
   *                                end data do not correspond to a graph node
   */
  public List<NodeType> shortestPathData(NodeType start, NodeType end) {
    LinkedList<NodeType> path = new LinkedList<>();
    // Follow predecessor links from the end back to the start, prepending each node's data so
    // the resulting list is in start-to-end order.
    for (SearchNode step = computeShortestPath(start, end); step != null; step = step.predecessor) {
      path.addFirst(step.node.data);
    }
    return path;
  }

  /**
   * Returns the cost of the path (sum over edge weights) of the shortest path from the node
   * containing the start data to the node containing the end data. This method uses Dijkstra's
   * shortest path algorithm to find this solution.
   *
   * @param start the data item in the starting node for the path
   * @param end   the data item in the destination node for the path
   * @return the cost of the shortest path between these nodes
   * @throws NoSuchElementException when no path from start to end is found or when either start or
   *                                end data do not correspond to a graph node
   */
  public double shortestPathCost(NodeType start, NodeType end) {
    return computeShortestPath(start, end).cost;
  }

  /**
   * Builds the ten-node directed, weighted graph shared by the tests below.
   *
   * @return a new graph containing the test nodes and edges
   */
  private static DijkstraGraph<String, Integer> createTestGraph() {
    DijkstraGraph<String, Integer> graph = new DijkstraGraph<>();
    for (String data : List.of("A", "B", "M", "I", "E", "D", "F", "G", "H", "L")) {
      graph.insertNode(data);
    }
    graph.insertEdge("A", "B", 1);
    graph.insertEdge("A", "M", 5);
    graph.insertEdge("A", "H", 7);
    graph.insertEdge("B", "M", 3);
    graph.insertEdge("M", "I", 4);
    graph.insertEdge("M", "E", 3);
    graph.insertEdge("M", "F", 4);
    graph.insertEdge("I", "H", 2);
    graph.insertEdge("I", "D", 1);
    graph.insertEdge("D", "A", 7);
    graph.insertEdge("D", "F", 4);
    graph.insertEdge("D", "G", 2);
    graph.insertEdge("F", "G", 9);
    graph.insertEdge("G", "A", 4);
    graph.insertEdge("G", "L", 7);
    graph.insertEdge("G", "H", 9);
    graph.insertEdge("H", "I", 2);
    graph.insertEdge("H", "B", 6);
    graph.insertEdge("H", "L", 2);
    return graph;
  }

  /**
   * Checks a multi-hop shortest path: D-G-H-I (2+9+2 = 13) beats e.g. D-G-A-B-M-I (14).
   */
  @Test
  public void test1() {
    DijkstraGraph<String, Integer> graph = createTestGraph();
    assertEquals(13.0, graph.shortestPathCost("D", "I"));
    assertEquals(List.of("D", "G", "H", "I"), graph.shortestPathData("D", "I"));
  }

  /**
   * Checks that a direct edge is chosen when it is cheaper than any longer route.
   */
  @Test
  public void test2() {
    DijkstraGraph<String, Integer> graph = createTestGraph();
    assertEquals(7.0, graph.shortestPathCost("A", "H"));
    assertEquals(List.of("A", "H"), graph.shortestPathData("A", "H"));
  }

  /**
   * Checks that an unreachable destination raises NoSuchElementException (L has no outgoing edges).
   */
  @Test
  public void test3() {
    DijkstraGraph<String, Integer> graph = createTestGraph();
    assertThrows(NoSuchElementException.class, () -> graph.shortestPathData("L", "H"));
    assertThrows(NoSuchElementException.class, () -> graph.shortestPathCost("L", "H"));
  }

  /**
   * Checks that a path from a node to itself has cost zero and contains only that node.
   */
  @Test
  public void testSameStartAndEnd() {
    DijkstraGraph<String, Integer> graph = createTestGraph();
    assertEquals(0.0, graph.shortestPathCost("A", "A"));
    assertEquals(List.of("A"), graph.shortestPathData("A", "A"));
  }

  /**
   * Checks that NoSuchElementException is raised when start or end is not in the graph.
   */
  @Test
  public void testMissingNode() {
    DijkstraGraph<String, Integer> graph = createTestGraph();
    assertThrows(NoSuchElementException.class, () -> graph.shortestPathCost("A", "Z"));
    assertThrows(NoSuchElementException.class, () -> graph.shortestPathData("Z", "A"));
  }

}
