///////////////////////////////////////////
///
/// MIDWEEK SUBMISSION WITH INSERT METHOD
/// File: RedBlackTree.java
/// Author: Evan Gaul
/// Email: eagaul@wisc.edu
/// Date: 9/28/25
/// Assignment: P104.RedBlackTree
/// Course: CS400 Lec004 Fall 2025
/// Professor: Ashley Samuelson
///
///////////////////////////////////////////
///
/// Assistance: None
///
///////////////////////////////////////////

import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Test;

/**
 * RedBlackTree.java
 * This class implements a Red-Black Tree, specifically insertion
 */
public class RedBlackTree<T extends Comparable<T>> extends BSTRotation<T>{

    public RedBlackTree() {
        super();
    }

    /**
     * Checks if a new red node in the RedBlackTree causes a red property violation
     * by having a red parent. If this is not the case, the method terminates without
     * making any changes to the tree. If a red property violation is detected, then
     * the method repairs this violation and any additional red property violations
     * that are generated as a result of the applied repair operation.
     * Using this method might cause nodes with a value equal to the value of one of
     * their ancestors to appear within the left and the right subtree of that ancestor,
     * even if the original insertion procedure consistently inserts such nodes into only
     * the left or the right subtree. But it will preserve the ordering of nodes within
     * the tree.
     * @param newNode a newly inserted red node, or a node turned red by previous repair
     */
    protected void ensureRedProperty(RedBlackNode<T> newNode) {
        /**
         * Cases:
         * 1. Parent red, aunt red
         *      Recolor parent and aunt to black, grandparent to red, move to grandparent and start checking
         * 2. Parent red, aunt black, zigzag (on inside)
         *      Rotate parent and and child and then do case 3
         * 3. Parent red, aunt red, with child as the parent's outer child
         *      Rotate grandparent and parent in opposite direction and swap colors of parent and grandparent
         *
         * Make sure root is always black
         */
        if (newNode == root || newNode == null) { // make sure root is black, base case as well
            newNode.isBlackNode = true;
            return;
        }
        RedBlackNode<T> parent = newNode.getParent();
        if (parent.isBlackNode) return; // New Node has black parent, all good

        RedBlackNode<T> grandparent = parent.getParent();
        if (grandparent == null) { // no grandparent, make sure parent is black because that is root now
            parent.isBlackNode = true; // parent becomes root
            return;
        }

        // Find aunt
        RedBlackNode<T> aunt;
        if (grandparent.getLeft() == parent) {
            aunt = grandparent.getRight();
        } else {
            aunt = grandparent.getLeft();
        }

        // Case 1, aunt red
        if (aunt != null && !aunt.isBlackNode) { // aunt is there and is red
            // Just recolor and recursively go up to grandparent to check
            parent.isBlackNode = true;
            aunt.isBlackNode = true;
            grandparent.isBlackNode = false;
            ensureRedProperty(grandparent);
            return;
        }

        // Case 2/3, aunt black or null
        if (parent == grandparent.getLeft()) {
            if (newNode == parent.getRight()) { // Zig zag
                rotate(newNode, parent);
                newNode = newNode.getLeft(); // update newNode
                parent = newNode.getParent();
            }
            // Left and left case, rotate right on grandparent
            parent.isBlackNode = true;
            grandparent.isBlackNode = false;
            rotate(parent, grandparent);
        } else {
            if (newNode == parent.getLeft()) {
                // Zigzag, rotate
                rotate(newNode, parent);
                newNode = newNode.getRight();
                parent = newNode.getParent();
            }
            // Right and right case, rotate left on grandparent
            parent.isBlackNode = true;
            grandparent.isBlackNode = false;
            rotate(parent, grandparent);
        }
    }

    /**
     * Insets a new node into the RedBlack Tree
     * Makes the new node red and calls ensureRedProperty()
     */
    @Override
    public void insert(T data) {
        RedBlackNode<T> newNode = new RedBlackNode<>(data);

        if (this.root == null) {
            // First node in the tree, make it root and color it black
            this.root = newNode;
            newNode.isBlackNode = true;
        } else {
            // Insert like a regular BST
            insertHelper(newNode, this.root);
            ensureRedProperty(newNode);
        }
    }

    ///////////////////////////////////////////////////////////////////////////
    ///                             TEST METHODS!
    ///////////////////////////////////////////////////////////////////////////


    /**
     * Test 1
     * Insert 10, 20, 30
     * After it is balanced, tree should look like:
     *          20B
     *       10R   30R
     *
     */
    @Test
    public void testOne() {
        RedBlackTree<Integer> tree = new RedBlackTree<>();
        tree.insert(10);
        tree.insert(20);
        tree.insert(30);

        RedBlackNode<Integer> root = (RedBlackNode<Integer>) tree.root;
        // Make sure root is correct value and color
        assertEquals(20, root.getData());
        assertTrue(root.isBlackNode());

        // Check other values and their colors
        assertEquals(10, root.getLeft().getData());
        assertFalse(root.getLeft().isBlackNode());
        assertEquals(30, root.getRight().getData());
        assertFalse(root.getRight().isBlackNode());
    }

    /**
     * Test 2
     * ZigZag rotation
     * Insert 10, 30, 20
     * Tree:
     *      10B         10B                 20B
     *        30R  ->      20R      ->    10R  30R
     *       20R               30R
     */
    @Test
    public void testTwo() {
        RedBlackTree<Integer> tree = new RedBlackTree<>();
        tree.insert(10);
        tree.insert(30);
        tree.insert(20);

        RedBlackNode<Integer> root = (RedBlackNode<Integer>) tree.root;
        assertEquals(20, root.getData());
        assertTrue(root.isBlackNode());

        assertEquals(10, root.getLeft().getData());
        assertFalse(root.getLeft().isBlackNode());
        assertEquals(30, root.getRight().getData());
        assertFalse(root.getRight().isBlackNode());

    }

    /**
     * Test 3, Q03 example, Question 1
     * Original tree from question:
     *                  MB
     *               GB    UR
     *            DR  IR  QB  XB
     *                       VR YR
     * The question then inserts Z, after that the tree should look like this:
     *                       UB
     *                   MR      XR
     *                GB   QB  VB   YB
     *              DR  IR            ZR
     */
    @Test
    public void testThree() {
        RedBlackTree<String> tree = new RedBlackTree<>();

        tree.insert("M");
        tree.insert("G");
        tree.insert("U");
        tree.insert("D");
        tree.insert("I");
        tree.insert("Q");
        tree.insert("X");
        tree.insert("V");
        tree.insert("Y");
        // Check whole tree
        RedBlackNode<String> root = (RedBlackNode<String>) tree.root;
        assertEquals("M", root.getData()); // M, black, root
        assertTrue(root.isBlackNode());
        assertEquals("G", root.getLeft().getData()); // G, black
        assertTrue(root.getLeft().isBlackNode());
        assertEquals("U", root.getRight().getData()); // U, red
        assertFalse(root.getRight().isBlackNode());
        assertEquals("D", root.getLeft().getLeft().getData()); // D, red
        assertFalse(root.getLeft().getLeft().isBlackNode());
        assertEquals("I", root.getLeft().getRight().getData()); // I, red
        assertFalse(root.getLeft().getRight().isBlackNode());
        assertEquals("Q", root.getRight().getLeft().getData()); // Q, Black
        assertTrue(root.getRight().getLeft().isBlackNode());
        assertEquals("X", root.getRight().getRight().getData()); // X, black
        assertTrue(root.getRight().getRight().isBlackNode());
        assertEquals("V", root.getRight().getRight().getLeft().getData()); // V, red
        assertFalse(root.getRight().getRight().getLeft().isBlackNode());
        assertEquals("Y", root.getRight().getRight().getRight().getData()); // Y, red
        assertFalse(root.getRight().getRight().getRight().isBlackNode());

        // insert Z now
        tree.insert("Z");

        // check whole tree once more
        RedBlackNode<String> root2 = (RedBlackNode<String>) tree.root;
        assertEquals("U", root2.getData()); // U, black, root
        assertTrue(root2.isBlackNode());
        assertEquals("M", root2.getLeft().getData()); // M, red
        assertFalse(root2.getLeft().isBlackNode());
        assertEquals("X", root2.getRight().getData()); // X, red
        assertFalse(root2.getRight().isBlackNode());
        assertEquals("G", root2.getLeft().getLeft().getData()); // G, black
        assertTrue(root2.getLeft().getLeft().isBlackNode());
        assertEquals("Q", root2.getLeft().getRight().getData()); // Q, black
        assertTrue(root2.getLeft().getRight().isBlackNode());
        assertEquals("V", root2.getRight().getLeft().getData()); // V, Black
        assertTrue(root2.getRight().getLeft().isBlackNode());
        assertEquals("Y", root2.getRight().getRight().getData()); // Y, black
        assertTrue(root2.getRight().getRight().isBlackNode());
        assertEquals("D", root2.getLeft().getLeft().getLeft().getData()); // D, red
        assertFalse(root2.getLeft().getLeft().getLeft().isBlackNode());
        assertEquals("I", root2.getLeft().getLeft().getRight().getData()); // I, red
        assertFalse(root2.getLeft().getLeft().getRight().isBlackNode());
        assertEquals("Z", root2.getRight().getRight().getRight().getData()); // Z, red
        assertFalse(root2.getRight().getRight().getRight().isBlackNode());
    }
}

