import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;


/**
 * An implementation of the MapADT using a hash table that resolves collisions by chaining:
 * each array slot holds a LinkedList of the key-value pairs whose keys map to that slot.
 *
 * <p>Null keys are rejected with a NullPointerException. After an insertion, the table doubles
 * its capacity whenever the load factor (size / capacity) reaches 0.75. This class performs no
 * synchronization of its own.
 *
 * <p>The JUnit tests for this class are declared at the bottom of the class body.
 */
public class HashTableMap<KeyType, ValueType> implements MapADT<KeyType, ValueType> {

    /** Load factor at or above which put() doubles the table capacity. */
    private static final double RESIZE_LOAD_FACTOR = 0.75;

    /**
     * A single key-value entry stored in one of the table's chains.
     */
    protected class Pair {
        public KeyType key;
        public ValueType value;

        /**
         * Creates a pair holding the given key and value.
         *
         * @param key   the key
         * @param value the value associated with the key
         */
        public Pair(KeyType key, ValueType value) {
            this.key = key;
            this.value = value;
        }
    }

    // Underlying array of chains; a slot stays null until the first key hashes to it
    protected LinkedList<Pair>[] table = null;
    // Number of key-value pairs currently stored
    private int size = 0;

    /**
     * Initializes a new, empty hash table with the given capacity.
     *
     * @param capacity the initial capacity (number of slots) of the hash table
     * @throws IllegalArgumentException if capacity is less than 1
     */
    @SuppressWarnings("unchecked")
    public HashTableMap(int capacity) {
        if (capacity < 1) {
            throw new IllegalArgumentException("Capacity less than 1");
        }
        // Generic arrays cannot be created directly: instantiate the raw type and cast
        this.table = (LinkedList<Pair>[]) new LinkedList[capacity];
    }

    /**
     * Initializes a new, empty hash table with the default capacity of 8.
     */
    public HashTableMap() {
        this(8);
    }

    /**
     * Computes the table slot for a key.
     *
     * <p>The remainder is taken before the absolute value because Math.abs(Integer.MIN_VALUE) is
     * still negative. Abs-then-mod would therefore produce a negative index for that hash code
     * whenever the capacity does not divide 2^31. For every other hash code, both orders yield
     * the same slot.
     *
     * @param key the non-null key to hash
     * @return an index in the range [0, table.length)
     */
    private int getIndex(KeyType key) {
        return Math.abs(key.hashCode() % table.length);
    }

    /**
     * Looks up the pair stored under a key.
     *
     * @param key the key to search for
     * @return the matching pair, or null if the key is not in the table
     * @throws NullPointerException if key is null
     */
    private Pair findPair(KeyType key) {
        if (key == null) {
            throw new NullPointerException("Key cannot be null");
        }
        LinkedList<Pair> chain = table[getIndex(key)];
        if (chain != null) {
            for (Pair pair : chain) {
                if (pair.key.equals(key)) {
                    return pair;
                }
            }
        }
        return null;
    }

    /**
     * Appends a pair to the chain for its key's slot, creating the chain if it doesn't exist yet.
     * Does not check for duplicate keys and does not update size.
     *
     * @param pair the pair to store
     */
    private void addToChain(Pair pair) {
        int index = getIndex(pair.key);
        if (table[index] == null) {
            table[index] = new LinkedList<Pair>();
        }
        table[index].add(pair);
    }

    /**
     * Adds a new key-value pair. If the load factor then reaches 0.75, the capacity is doubled.
     *
     * @param key   the key to add; must not already be present
     * @param value the value to associate with the key
     * @throws NullPointerException     if key is null
     * @throws IllegalArgumentException if key is already in the map
     */
    @Override
    public void put(KeyType key, ValueType value) throws IllegalArgumentException {
        if (findPair(key) != null) {
            throw new IllegalArgumentException("Key is already in the map");
        }

        addToChain(new Pair(key, value));
        size++;

        if ((double) size / table.length >= RESIZE_LOAD_FACTOR) {
            resizeTable();
        }
    }

    /**
     * Doubles the table capacity and redistributes every existing pair into the new array.
     *
     * <p>Pairs are moved directly rather than re-inserted through put(). They are already known
     * to be unique, so the duplicate scan and load-factor check are unnecessary, and size does
     * not change. Old slots are visited in index order and each chain front to back.
     */
    @SuppressWarnings("unchecked")
    private void resizeTable() {
        LinkedList<Pair>[] oldTable = table;
        table = (LinkedList<Pair>[]) new LinkedList[oldTable.length * 2];

        for (LinkedList<Pair> chain : oldTable) {
            if (chain != null) {
                for (Pair pair : chain) {
                    addToChain(pair);
                }
            }
        }
    }

    /**
     * Checks whether a key is stored in the map.
     *
     * @param key the key to look for
     * @return true if the key is present, false otherwise
     * @throws NullPointerException if key is null
     */
    @Override
    public boolean containsKey(KeyType key) {
        return findPair(key) != null;
    }

    /**
     * Returns the value stored under a key.
     *
     * @param key the key to look up
     * @return the value associated with the key
     * @throws NullPointerException   if key is null
     * @throws NoSuchElementException if the key is not in the map
     */
    @Override
    public ValueType get(KeyType key) throws NoSuchElementException {
        Pair pair = findPair(key);
        if (pair == null) {
            throw new NoSuchElementException("Key not found in table");
        }
        return pair.value;
    }

    /**
     * Removes a key and its value from the map.
     *
     * @param key the key to remove
     * @return the value that was associated with the key
     * @throws NullPointerException   if key is null
     * @throws NoSuchElementException if the key is not in the map
     */
    @Override
    public ValueType remove(KeyType key) throws NoSuchElementException {
        if (key == null) {
            throw new NullPointerException("Key cannot be null");
        }
        LinkedList<Pair> chain = table[getIndex(key)];
        if (chain != null) {
            // Iterator removal unlinks in O(1); LinkedList.get(i) in a loop would be O(n^2)
            Iterator<Pair> iterator = chain.iterator();
            while (iterator.hasNext()) {
                Pair pair = iterator.next();
                if (pair.key.equals(key)) {
                    iterator.remove();
                    size--;
                    return pair.value;
                }
            }
        }
        throw new NoSuchElementException("Key not found in table, can't remove.");
    }

    /**
     * Returns the current number of slots in the underlying array.
     *
     * @return the table capacity
     */
    @Override
    public int getCapacity() {
        return table.length;
    }

    /**
     * Returns a new list of every key in the map, in slot order and then chain order.
     *
     * @return a freshly allocated list of the stored keys
     */
    @Override
    public List<KeyType> getKeys() {
        List<KeyType> keyList = new LinkedList<KeyType>();
        for (LinkedList<Pair> chain : table) {
            if (chain != null) {
                for (Pair pair : chain) {
                    keyList.add(pair.key);
                }
            }
        }
        return keyList;
    }

    /**
     * Removes every pair from the map. The capacity is left unchanged.
     */
    @Override
    public void clear() {
        for (int i = 0; i < table.length; i++) {
            table[i] = null;
        }
        size = 0;
    }

    /**
     * Returns the number of key-value pairs stored in the map.
     *
     * @return the current size
     */
    @Override
    public int getSize() {
        return size;
    }



    // JUnit Tests below this line
    // _______________________________________________

    /**
     * Tests basic functionality of put, get, and getSize. Checks that key-value
     * pairs can be added and retrieved correctly.
     */
    @Test
    public void test1BasicPutGet() {
        HashTableMap<String, Integer> map = new HashTableMap<>(10);
        map.put("One", 1);
        map.put("Two", 2);

        // Check that size has updated correctly
        assertEquals(2, map.getSize(), "Size should be 2 after 2 puts");
        // Check that retrieval works
        assertEquals(1, map.get("One"), "Retrieving 'One' should return 1.");
        assertEquals(2, map.get("Two"), "Retrieving 'Two' should return 2.");
    }

    /**
     * Tests collision handling.
     * Adds multiple keys that hash to the same index and makes sure all can be accessed.
     */
    @Test
    public void test2Collisions() {
        // Initial capacity 8. Integer hash codes are the numbers themselves,
        // so 0 and 8 both map to index 0
        HashTableMap<Integer, String> map = new HashTableMap<>(8);
        map.put(0, "Zero");
        map.put(8, "Eight");

        // Both should exist at the same index, in different nodes of the LinkedList
        assertTrue(map.containsKey(0), "Map should contain key 0");
        assertTrue(map.containsKey(8), "Map should contain key 8");
        assertEquals("Zero", map.get(0), "Retrieving 0 should work despite collision");
        assertEquals("Eight", map.get(8), "Retrieving 8 should work despite collision");
    }

    /**
     * Tests the removal of elements and the clear method.
     * Ensures that keys are removed properly and that the size is reset after clear.
     */
    @Test
    public void test3RemoveAndClear() {
        HashTableMap<String, String> map = new HashTableMap<>(5);
        map.put("A", "Apple");
        map.put("B", "Banana");

        // Test removal
        String removed = map.remove("A");
        assertEquals("Apple", removed, "Remove should return the value of the removed key");
        assertFalse(map.containsKey("A"), "Key A should no longer exist");
        assertEquals(1, map.getSize(), "Size should be 1 after removal");
        assertThrows(NoSuchElementException.class, () -> map.remove("A"),
            "Removing a key that is no longer present should throw exception");

        // Test clear
        map.clear();
        assertEquals(0, map.getSize(), "Size should be 0 after clear");
        assertThrows(NoSuchElementException.class, () -> map.get("B"), "Trying to get a key after clear should throw exception");
    }

    /**
     * Tests the resize trigger logic based on load factor
     * and verifies that capacity doubles exactly when the load factor reaches 75%.
     */
    @Test
    public void test4ResizeTrigger() {
        // Capacity 10 - threshold is 7.5
        // Resize should trigger when size is 8
        HashTableMap<Integer, Integer> map = new HashTableMap<>(10);

        // Add 7 elements
        for (int i = 0; i < 7; i++) {
            map.put(i, i);
        }

        assertEquals(10, map.getCapacity(), "Capacity should be 10 still at 7 size");

        // Add the 8th element to trigger resize
        map.put(7, 7);
        assertEquals(20, map.getCapacity(), "Capacity should double to 20 when size reaches 8");
    }

    /**
     * Tests rehashing correctness during resize.
     * Ensures that all keys are still accessible and that getKeys
     * returns the correct list after resize.
     */
    @Test
    public void test5RehashCorrectness() {
        HashTableMap<Integer, Integer> map = new HashTableMap<>(4);
        map.put(1, 10);
        map.put(2, 20);
        map.put(3, 30);
        // Size 3, capacity 4 -- load factor 0.75 should have triggered a resize

        assertEquals(8, map.getCapacity(), "Capacity should have doubled to 8");
        assertEquals(3, map.getSize(), "Size should be unchanged by resizing");

        // Ensure all keys are still retrievable and rehashed correctly to new indices
        List<Integer> keys = map.getKeys();
        assertEquals(3, keys.size(), "getKeys should return 3 keys");
        assertTrue(keys.contains(1) && keys.contains(2) && keys.contains(3), "All original keys should be present");
        assertEquals(20, map.get(2), "Value for key 2 should still be 20 after rehashing");
    }

    /**
     * Tests a key whose hashCode() is Integer.MIN_VALUE in a table whose capacity does not
     * divide 2^31. Math.abs(Integer.MIN_VALUE) is negative, so this key must not produce a
     * negative index.
     */
    @Test
    public void test6MinValueHashCode() {
        HashTableMap<Integer, String> map = new HashTableMap<>(10);
        map.put(Integer.MIN_VALUE, "Min");

        assertTrue(map.containsKey(Integer.MIN_VALUE), "Map should contain Integer.MIN_VALUE");
        assertEquals("Min", map.get(Integer.MIN_VALUE), "Retrieving Integer.MIN_VALUE should work");
        assertEquals("Min", map.remove(Integer.MIN_VALUE), "Removing Integer.MIN_VALUE should return its value");
        assertEquals(0, map.getSize(), "Size should be 0 after removing the only key");
    }
}
