## Sid Meier's Civilization 4
## 
## This file is part of the UnitUpgradesPediaMod by Vovan
##

import string

from CvPythonExtensions import *

from base import DiGraph
import search
import paths

# globals
# Handle to the game's global context; used below for game state
# (active player / civilization) and unit, unit-class and civilization info.
gc = CyGlobalContext()
# Handle to the interface art registry; drawArrow uses it to resolve the
# image paths of the ARROW_* art entries.
ArtFileMgr = CyArtFileMgr()

class UnitUpgradesGraph:
	"""The graph of unit upgrades.

	Builds a directed graph whose nodes are unit ids and whose edges point
	from a unit to each unit it upgrades to. Each connected component is
	arranged into columns ("layers"), given pixel positions, and drawn
	with connecting arrows on the Pedia screen.

	Note: this file targets the game's embedded Python 2 interpreter, so
	only Python 2.4-compatible constructs are used.
	"""

	def __init__(self, main):
		"""Remember the owning screen object.

		main -- owner of this graph; drawArrow reads its UPGRADES_LIST
		        attribute as the parent widget for the arrow images.
		"""
		self.top = main

		# A node with more predecessors than this is moved into a column of
		# its own by expandLargeNodes.
		self.iMaximumPredecessors = 3
		# A node whose degree exceeds this gets an empty slot inserted
		# before it in its column by expandLargeNodes.
		self.iMaximumNeighbors = 3

	################## Stuff to generate Unit Upgrade Graph ##################

	def getActivePlayer(self):
		"Gets the id of the active player for UU upgrades (-1 when there is none)"
		return gc.getGame().getActivePlayer()

	def getUnitNumber(self, k):
		"""Return the unit id used for unit class k.

		With no active player (-1), the unit class's default unit is used;
		otherwise the active civilization's unit for that class is used.
		The id may be negative when there is no unit for the class; callers
		must check for >= 0.
		"""
		if self.getActivePlayer() == -1:
			return gc.getUnitClassInfo(k).getDefaultUnitIndex()
		civInfo = gc.getCivilizationInfo(gc.getGame().getActiveCivilizationType())
		return civInfo.getCivilizationUnits(k)

	def getUnitName(self, e):
		"Returns the description (display name) of the unit with id e"
		return gc.getUnitInfo(e).getDescription()

	def getAllUpgrades(self, eUnit):
		"""Return the ids of all units that unit eUnit upgrades to.

		Each upgrade target class is resolved to a unit id with
		getUnitNumber; classes with no unit (negative id) are skipped.
		A negative eUnit is not a real unit and has no upgrades.
		"""
		if eUnit < 0:
			# No unit info can be looked up for a negative id.
			return []
		unitInfo = gc.getUnitInfo(eUnit)  # loop-invariant; fetch once
		result = []
		for k in range(gc.getNumUnitClassInfos()):
			eLoopUnit = self.getUnitNumber(k)
			if eLoopUnit >= 0 and unitInfo.getUpgradeUnitClass(k):
				result.append(eLoopUnit)
		return result

	def addUnitUpgradesToGraph(self, theGraph):
		"""Add an edge unit -> upgrade for every upgrade of every unit.

		theGraph must provide add_edge(a, b). It is modified in place and
		also returned, so calls can be chained.
		"""
		for k in range(gc.getNumUnitClassInfos()):
			eLoopUnit = self.getUnitNumber(k)
			if eLoopUnit < 0:
				# This unit class has no unit for the current civilization;
				# skip it, just as getAllUpgrades skips such upgrade targets.
				continue
			for eUpgradeUnit in self.getAllUpgrades(eLoopUnit):
				theGraph.add_edge(eLoopUnit, eUpgradeUnit)
		return theGraph


	################# Stuff to topologically sort the graph ##################

	def assignOrderToNodes(self, theGraph, node, map, order):
		"""Assign a column number to every node connected to node.

		Walks the graph ignoring edge direction: successors get order + 1,
		predecessors order - 1. A node keeps the first order it is given.
		map (node -> order) is filled in place and also returned. The walk
		is recursive, so recursion depth grows with the component size.
		"""
		if node in map:
			return map
		map[node] = order
		for n in theGraph.successors(node):
			map = self.assignOrderToNodes(theGraph, n, map, order + 1)
		for n in theGraph.predecessors(node):
			map = self.assignOrderToNodes(theGraph, n, map, order - 1)
		return map

	def extremeOrder(self, map, bSmallest):
		"""Return the smallest (bSmallest) or largest order value in map.

		0 always takes part in the comparison, so an empty map yields 0.
		"""
		values = [0] + list(map.values())
		if bSmallest:
			return min(values)
		return max(values)

	def orderNodes(self, map):
		"""Group nodes by order value into a list of layers.

		Layer i holds the nodes whose order is (smallest order + i). Within
		a layer, nodes keep the iteration order of map.
		"""
		base = self.extremeOrder(map, True)
		end = self.extremeOrder(map, False)
		result = [[] for i in range(end - base + 1)]
		for unit in map.keys():
			result[map[unit] - base].append(unit)
		return result

	def hasNodes(self, layer):
		"Returns True if layer contains at least one real node (id > -1)"
		for u in layer:
			if u > -1:
				return True
		return False

	def expandLargeNodes(self, theGraph, layers):
		"""Spread out crowded nodes so their arrows have room.

		For each layer:
		- A node with more than iMaximumPredecessors predecessors is moved
		  into a new layer of its own, placed before the rest of its
		  original layer. It stays at the same row index, and every other
		  slot is -1. A -1 placeholder is left at its old position.
		- A node whose degree exceeds iMaximumNeighbors gets a -1
		  placeholder inserted before it.
		- A layer left with no real nodes is dropped.
		Returns the new list of layers; -1 marks an empty slot.
		"""
		result = []
		for layer in layers:
			remaining = []
			for unitIndex, unit in enumerate(layer):
				if len(theGraph.predecessors(unit)) > self.iMaximumPredecessors:
					ownLayer = len(layer) * [-1]
					ownLayer[unitIndex] = unit
					result.append(ownLayer)
					remaining.append(-1)
				else:
					if theGraph.degree(unit) > self.iMaximumNeighbors:
						remaining.append(-1)
					remaining.append(unit)
			if self.hasNodes(remaining):
				result.append(remaining)
		return result

	def compressToLayers(self, theGraph):
		"""Arrange theGraph's nodes into layers (columns) for display.

		The walk starts from theGraph.nodes()[0] and only reaches nodes
		connected to it, so theGraph is expected to be a single non-empty
		connected component (presumably one entry from getGraph).
		"""
		orders = self.assignOrderToNodes(theGraph, theGraph.nodes()[0], {}, 0)
		return self.expandLargeNodes(theGraph, self.orderNodes(orders))

	def getGraph(self):
		"""Build the full upgrade graph and split it into connected parts.

		The splitting is done by search.connected_component_subgraphs;
		presumably the result is a list of DiGraph subgraphs.
		"""
		return search.connected_component_subgraphs(self.addUnitUpgradesToGraph(DiGraph()))


	################## Stuff to lay out the graph in space ###################

	def maximumLayerSize(self, layers):
		"Returns the length of the longest layer, or -1 if there are none"
		return max([-1] + [len(l) for l in layers])

	def calculateLayerHeight(self, units, unitHeight, margin):
		"Returns the height of a column of units slots with margin between slots"
		return units * unitHeight + (units - 1) * margin

	def layoutLayers(self, graph, verticalOffset, unitWidth, unitHeight, horizontalMargin, verticalMargin):
		"""Lay out the layers of graph and return the unit positions.

		Returns one list per layer (column). Each list holds (unit, [x, y])
		pairs in row order, where unit may be -1 for an empty slot. Columns
		are placed left to right starting at x = 0. Each column is centred
		vertically against the tallest one and shifted by verticalOffset.
		"""
		layers = self.compressToLayers(graph)
		graphHeight = self.calculateLayerHeight(self.maximumLayerSize(layers), unitHeight, verticalMargin)
		result = []
		xPosition = 0
		for layer in layers:
			layerHeight = self.calculateLayerHeight(len(layer), unitHeight, verticalMargin)
			# Floor division keeps pixel coordinates integral; it matches
			# Python 2 integer "/" for the int sizes used here.
			yPosition = verticalOffset + (graphHeight - layerHeight) // 2
			positions = []
			for unit in layer:
				positions.append((unit, [xPosition, yPosition]))
				yPosition += unitHeight + verticalMargin
			result.append(positions)
			xPosition += unitWidth + horizontalMargin
		return result


	####################### Stuff to draw graph arrows #######################

	def drawGraphArrows(self, pediaScreen, graph, layers):
		"""Draw every upgrade arrow for a laid-out graph.

		pediaScreen -- provides getScreen(), getNextWidgetName() and
		               BUTTON_SIZE; stored on self for drawArrow.
		layers      -- output of layoutLayers: per column, (unit, [x, y])
		               pairs, with unit == -1 for empty slots.
		Columns are processed right to left. Within a column, arrows that
		end in it are numbered consecutively (startArrow) so that drawArrow
		can offset their vertical segments from one another.
		"""
		self.pediaScreen = pediaScreen

		for li in range(len(layers) - 1, -1, -1):
			layer = layers[li]
			arrowCount = self.arrowsInLayer(graph, layer)
			startArrow = 0
			for unit, position in layer:
				if unit > -1:
					startArrow = self.drawUnitArrows(graph, layers, unit, position, li, startArrow, arrowCount)

	def arrowsInLayer(self, graph, layer):
		"Returns the number of arrows ending in layer (sum of predecessors of its real units)"
		return sum([len(graph.predecessors(unit)) for unit, pos in layer if unit > -1])

	def drawUnitArrows(self, graph, layers, unit, position, layerIndex, startArrow, arrowCount):
		"""Draw the arrows into unit from its predecessors.

		Only layers 0..layerIndex are searched; the unit's own layer is
		included, and drawArrow draws such arrows vertically. startArrow is
		advanced once per arrow drawn, and the updated value is returned.
		"""
		unitArrowCount = len(graph.predecessors(unit))
		unitArrowIndex = 0
		for li in range(layerIndex + 1):
			for otherUnit, otherPosition in layers[li]:
				if otherUnit > -1 and unit in graph.successors(otherUnit):
					startArrow += 1
					unitArrowIndex += 1
					self.drawArrow(otherPosition, position, startArrow, arrowCount + 1, unitArrowIndex, unitArrowCount + 1, layerIndex - li)
		return startArrow

	def drawArrow(self, posFrom, posTo, layerArrowIndex, layerArrowCount, unitArrowIndex, unitArrowCount, ldiff):
		"""Draw one arrow from the button at posFrom to the button at posTo.

		posFrom, posTo   -- [x, y] of the top-left corners of the buttons.
		layerArrowIndex /
		layerArrowCount  -- position of this arrow among the arrows ending in
		                    the target column. It sets where the vertical
		                    segment sits within the horizontal run.
		unitArrowIndex /
		unitArrowCount   -- position of this arrow among the arrows ending in
		                    the target unit. It sets the arrival height along
		                    the target button.
		ldiff            -- number of columns between the two buttons; less
		                    than 1 means the same column (a vertical arrow).
		Arrows are drawn from 8-pixel ARROW_* images, added as children of
		the widget named by self.top.UPGRADES_LIST.
		"""
		screen = self.pediaScreen.getScreen()
		UpgradesList = self.top.UPGRADES_LIST
		buttonSize = self.pediaScreen.BUTTON_SIZE

		ARROW_X = ArtFileMgr.getInterfaceArtInfo("ARROW_X").getPath()
		ARROW_Y = ArtFileMgr.getInterfaceArtInfo("ARROW_Y").getPath()
		ARROW_MXMY = ArtFileMgr.getInterfaceArtInfo("ARROW_MXMY").getPath()
		ARROW_XY = ArtFileMgr.getInterfaceArtInfo("ARROW_XY").getPath()
		ARROW_MXY = ArtFileMgr.getInterfaceArtInfo("ARROW_MXY").getPath()
		ARROW_XMY = ArtFileMgr.getInterfaceArtInfo("ARROW_XMY").getPath()
		ARROW_HEAD = ArtFileMgr.getInterfaceArtInfo("ARROW_HEAD").getPath()

		def addImage(art, x, y, width, height):
			# Add one arrow piece under the upgrades panel, with a fresh widget name.
			screen.addDDSGFCAt(self.pediaScreen.getNextWidgetName(), UpgradesList, art, x, y, width, height, WidgetTypes.WIDGET_GENERAL, -1, -1, False)

		# The arrow leaves from the middle of the source button's right edge
		# and arrives on the target's left edge. The arrival height is
		# staggered by unitArrowIndex.
		xFrom = posFrom[0] + buttonSize
		xTo = posTo[0]
		yFrom = posFrom[1] + (buttonSize // 2)
		yTo = posTo[1] + (unitArrowIndex * buttonSize // unitArrowCount)

		xDiff = xTo - xFrom
		yDiff = yTo - yFrom

		# Reverse the numbering, so the first arrow gets the largest offset.
		layerArrowIndex = layerArrowCount - layerArrowIndex

		# sWidth is a fixed lead-in of the first horizontal segment. pWidth
		# is the span over which the vertical segment is spread by
		# layerArrowIndex.
		if ldiff > 1:
			pWidth = xDiff // ldiff - buttonSize
			sWidth = xDiff - pWidth
		else:
			pWidth = xDiff
			sWidth = 0

		if ldiff < 1:
			# Same column: one vertical segment between the facing edges.
			x = posFrom[0] + (buttonSize // 2)
			if yDiff < 0:
				y1 = posFrom[1]
				y2 = posTo[1] + buttonSize
			else:
				y1 = posFrom[1] + buttonSize
				y2 = posTo[1]
			addImage(ARROW_Y, x, y1, 8, y2 - y1)
			addImage(ARROW_HEAD, x, y2, 8, 8)
		elif yDiff == 0:
			# Level with the target: a single straight horizontal arrow.
			addImage(ARROW_X, xFrom, yTo, sWidth + pWidth, 8)
			addImage(ARROW_HEAD, xTo, yTo, 8, 8)
		elif yDiff < 0:
			# Go right:
			x1 = xFrom
			y1 = yFrom
			w1 = sWidth + layerArrowIndex * pWidth // layerArrowCount
			addImage(ARROW_X, x1, y1, w1, 8)

			# Go up:
			x2 = x1 + w1
			y2 = yTo
			h1 = y1 - y2
			addImage(ARROW_XY, x2, y1, 8, 8)
			addImage(ARROW_Y, x2, y2 + 8, 8, h1 - 8)

			# Go right:
			x3 = x2 + 8
			w2 = xTo - x3
			addImage(ARROW_XMY, x2, y2, 8, 8)
			addImage(ARROW_X, x3, y2, w2, 8)
			addImage(ARROW_HEAD, xTo, y2, 8, 8)
		else:  # yDiff > 0
			# Go right:
			x1 = xFrom
			y1 = yFrom
			w1 = sWidth + layerArrowIndex * pWidth // layerArrowCount
			addImage(ARROW_X, x1, y1, w1, 8)

			# Go down:
			x2 = x1 + w1
			h1 = yTo - y1
			addImage(ARROW_MXMY, x2, y1, 8, 8)
			addImage(ARROW_Y, x2, y1 + 8, 8, h1 - 8)

			# Go right:
			y2 = y1 + h1
			x3 = x2 + 8
			w2 = xTo - x3
			addImage(ARROW_MXY, x2, y2, 8, 8)
			addImage(ARROW_X, x3, y2, w2, 8)
			addImage(ARROW_HEAD, xTo, y2, 8, 8)
	
