use progvis:lang:cpp;
use progvis:lang:cpp:impl;
use lang:asm;
use core:lang;
use core:asm;
use lang:bs:macro;

/**
 * Note: This was originally in the progvis.lang.cpp package, but since the patching logic will
 * reload the Ptr class many times, it was moved here to speed up the system. In particular, we need
 * the copy constructor.
 */

// Pointers.
// Generator for 'Ptr': a non-const pointer (not a reference).
Ptr : generate(params) {
	Bool isConst = false;
	Bool isRef = false;
	Bool rvalRef = false;
	generatePtr(params, "Ptr", isConst, isRef, rvalRef);
}

// Generator for 'ConstPtr': a const pointer (not a reference).
ConstPtr : generate(params) {
	Bool isConst = true;
	Bool isRef = false;
	Bool rvalRef = false;
	generatePtr(params, "ConstPtr", isConst, isRef, rvalRef);
}

// References.
// Generator for 'Ref': a non-const (lvalue) reference.
Ref : generate(params) {
	Bool isConst = false;
	Bool isRef = true;
	Bool rvalRef = false;
	generatePtr(params, "Ref", isConst, isRef, rvalRef);
}

// Generator for 'ConstRef': a const (lvalue) reference.
ConstRef : generate(params) {
	Bool isConst = true;
	Bool isRef = true;
	Bool rvalRef = false;
	generatePtr(params, "ConstRef", isConst, isRef, rvalRef);
}

// RValue-ref.
// Generator for 'RRef': a non-const rvalue reference.
RRef : generate(params) {
	Bool isConst = false;
	Bool isRef = true;
	Bool rvalRef = true;
	generatePtr(params, "RRef", isConst, isRef, rvalRef);
}

// Generator for 'ConstRRef': a const rvalue reference.
ConstRRef : generate(params) {
	Bool isConst = true;
	Bool isRef = true;
	Bool rvalRef = true;
	generatePtr(params, "ConstRRef", isConst, isRef, rvalRef);
}


// Create the pointer/reference type called 'name' for the single type parameter in 'params'.
// The flags 'isConst', 'isRef' and 'rvalRef' are forwarded unchanged to PtrType.
// Returns null (no type) unless 'params' holds exactly one value type. The wrapped type is the
// parameter with its reference flag cleared (asRef(false)).
private Named? generatePtr(Array<Value> params, Str name, Bool isConst, Bool isRef, Bool rvalRef) {
	// Exactly one type parameter is accepted.
	if (params.count != 1)
		return null;

	// ...and it has to be a value type (other Ptr classes count as value types too).
	Value target = params[0];
	if (!target.isValue)
		return null;

	PtrType(target.asRef(false), name, isConst, isRef, rvalRef);
}


// Throw a PtrError ("undefined behavior") unless 'a' and 'b' refer to the same allocation.
// The check compares the first pointer-sized word of the two pointer values. That word is
// presumably the allocation's base pointer, as in pointerEq. It is called by pointerLt, pointerGt,
// pointerLte and pointerGte before they compare offsets.
void assumeSameAlloc(unsafe:RawPtr a, unsafe:RawPtr b) {
	Bool same;
	asm {
		mov ptrA, a;
		mov ptrB, b;
		cmp p[ptrA], p[ptrB];
		setCond same, ifEqual;
	}

	// Fixed operator list in the message: the operator is '>=', not '=>'.
	if (!same)
		throw PtrError("undefined behavior", "Trying to compare pointers from different allocations with <, >, <=, >=, or -");
}

// Make sure the allocation at 'base' has not been deallocated.
// A freed allocation has a zero element count, meaning the AllocFlags:sizeMask bits of its filled
// field are zero. In that case a 'use after free' PtrError is thrown.
void checkPtr(unsafe:RawPtr base) {
	Nat count = base.readFilled() & AllocFlags:sizeMask.v;
	if (count == 0)
		throw PtrError("use after free", "Trying to read from memory that was freed.");
}

// Check the validity of a pointer. Assumes we want to read at most 'size' bytes at the location that
// 'base' and 'offset' refer to.
// The allocation's data size is readSize() times the element count (the AllocFlags:sizeMask bits of
// the filled field). For value arrays, 'offset' includes the two pointer-sized words of the array
// header, and they are subtracted before the check.
// Throws PtrError "use after free" if the allocation is freed (size 0). Throws "buffer overflow" if
// the read would start before the data or end past it.
void checkPtr(unsafe:RawPtr base, Nat offset, Nat size) {
	Nat total = base.readSize() * (base.readFilled() & AllocFlags:sizeMask.v);

	// Array header.
	Nat header = 0;
	if (base.isValue)
		header = sPtr.current * 2;

	// Nat arithmetic wraps around. Computing 'offset - header + size' directly could wrap back into
	// range for a pointer before the start of the data (e.g. 'p - 1' on the first element) and let
	// an invalid read through. Each bound is therefore checked separately.
	Bool before = offset < header;
	Bool valid = !before;
	if (valid) {
		offset -= header;
		valid = offset <= total;
		if (valid)
			valid = size <= total - offset;
	}

	if (!valid) {
		if (total == 0)
			throw PtrError("use after free", "Trying to read from memory that was freed.");
		else if (before)
			throw PtrError("buffer overflow", "Trying to read before the start of an allocation of size ${total}.");
		else
			throw PtrError("buffer overflow", "Trying to read at offset ${offset} in an allocation of size ${total}.");
	}
}

// Validate a plain 'delete' of the pointer given as 'base' and 'offset'. The pointer must refer to
// the start of a heap allocation that is not an array allocation and has not already been freed.
// Returns 'false' for a null pointer (nothing to delete) and 'true' when the delete is allowed.
// Every other case throws a PtrError.
Bool checkDelete(unsafe:RawPtr base, Nat offset) {
	// Null pointer: nothing to delete.
	if (base.empty() & offset == 0)
		return false;

	// The start of an allocation lies just past the array header (two pointer-sized words).
	Nat allocStart = sPtr.current * 2;
	if (offset != allocStart)
		throw PtrError("memory", "Trying to delete memory not allocated by 'new'!");

	Nat flags = base.readFilled();

	// Only allocations carrying the heap flag may be deleted.
	if ((flags & AllocFlags:heapAlloc.v) == 0)
		throw PtrError("memory", "Trying to delete memory allocated on the stack!");

	// Allocations made by new[] must be released through delete[].
	if ((flags & AllocFlags:arrayAlloc.v) != 0)
		throw PtrError("memory", "This allocation was allocated using new[], and should be freed using delete[].");

	// An element count of zero means the allocation was already released.
	if ((flags & AllocFlags:sizeMask.v) == 0)
		throw PtrError("use after free", "Trying to free memory that was already freed.");

	return true;
}

// Validate a 'delete[]' of the pointer given as 'base' and 'offset'. The pointer must refer to the
// start of a heap allocation that has not already been freed.
// Returns 'false' if the pointer is a null pointer (nothing to delete) and 'true' when the delete is
// allowed. It does NOT return the number of elements. Every other case throws a PtrError.
Bool checkDeleteArray(unsafe:RawPtr base, Nat offset) {
	if (base.empty() & offset == 0)
		return false;

	// The start of an allocation lies just past the array header (two pointer-sized words).
	if (offset != sPtr.current * 2)
		throw PtrError("memory", "Trying to delete memory not allocated by 'new[]'!");

	Nat filled = base.readFilled;

	// Only allocations carrying the AllocFlags:heapAlloc flag may be deleted.
	if ((filled & AllocFlags:heapAlloc.v) == 0)
		throw PtrError("memory", "Trying to delete memory allocated on the stack!");

	// An element count of zero means the allocation was already released.
	if ((filled & AllocFlags:sizeMask.v) == 0)
		throw PtrError("use after free", "Trying to free memory that was already freed.");

	// Deliberately no check of AllocFlags:arrayAlloc here: memory from plain 'new' is accepted as
	// well, which is what makes 'free' feasible.
	return true;
}


/**
 * Various helper functions for the pointer class.
 */

// Compare two pointers for equality. Don't call from Storm.
// Two pointers are equal when their first pointer-sized words match (the allocation). Their
// following 32-bit words, at 'sPtr', must also match (presumably the offset).
Bool pointerEq(unsafe:RawPtr a, unsafe:RawPtr b) {
	// Stays false if the allocations differ; otherwise it is set from the offset comparison.
	Bool equal = false;
	asm {
		mov ptrA, a;
		mov ptrB, b;
		cmp p[ptrA], p[ptrB];
		jmp ifNotEqual, @done;

		cmp i[ptrA + sPtr], i[ptrB + sPtr];
		setCond equal, ifEqual;
	done:
	}
	equal;
}
// Compare two pointers for inequality. Don't call from Storm.
// The result is true when the pointers differ in either the allocation word or the 32-bit word
// after it. This is exactly the negation of pointerEq.
Bool pointerNeq(unsafe:RawPtr a, unsafe:RawPtr b) {
	!pointerEq(a, b);
}
// Is 'a' strictly below 'b'? Throws (via assumeSameAlloc) if they point into different allocations.
// The offset word is compared as a 32-bit unsigned value ('i'), like pointerEq and the Nat offsets
// used elsewhere in this file. This prevents bytes after the offset from influencing the result on
// 64-bit systems, which a pointer-sized 'p' read would allow.
Bool pointerLt(unsafe:RawPtr a, unsafe:RawPtr b) {
	assumeSameAlloc(a, b);

	Bool r;
	asm {
		mov ptrA, a;
		mov ptrB, b;
		cmp i[ptrA + sPtr], i[ptrB + sPtr];
		setCond r, ifBelow;
	}
	r;
}
// Is 'a' strictly above 'b'? Throws (via assumeSameAlloc) if they point into different allocations.
// The offset word is compared as a 32-bit unsigned value ('i'), like pointerEq and the Nat offsets
// used elsewhere in this file. This prevents bytes after the offset from influencing the result on
// 64-bit systems, which a pointer-sized 'p' read would allow.
Bool pointerGt(unsafe:RawPtr a, unsafe:RawPtr b) {
	assumeSameAlloc(a, b);

	Bool r;
	asm {
		mov ptrA, a;
		mov ptrB, b;
		cmp i[ptrA + sPtr], i[ptrB + sPtr];
		setCond r, ifAbove;
	}
	r;
}
// Is 'a' below or equal to 'b'? Throws (via assumeSameAlloc) if they point into different
// allocations. The offset word is compared as a 32-bit unsigned value ('i'), like pointerEq and the
// Nat offsets used elsewhere in this file. This prevents bytes after the offset from influencing the
// result on 64-bit systems.
Bool pointerLte(unsafe:RawPtr a, unsafe:RawPtr b) {
	assumeSameAlloc(a, b);

	Bool r;
	asm {
		mov ptrA, a;
		mov ptrB, b;
		cmp i[ptrA + sPtr], i[ptrB + sPtr];
		setCond r, ifBelowEqual;
	}
	r;
}
// Is 'a' above or equal to 'b'? Throws (via assumeSameAlloc) if they point into different
// allocations. The offset word is compared as a 32-bit unsigned value ('i'), like pointerEq and the
// Nat offsets used elsewhere in this file. This prevents bytes after the offset from influencing the
// result on 64-bit systems.
Bool pointerGte(unsafe:RawPtr a, unsafe:RawPtr b) {
	assumeSameAlloc(a, b);

	Bool r;
	asm {
		mov ptrA, a;
		mov ptrB, b;
		cmp i[ptrA + sPtr], i[ptrB + sPtr];
		setCond r, ifAboveEqual;
	}
	r;
}
