from mathutils import Matrix, Quaternion, Vector
from io import StringIO
import bpy, bmesh, os

from . import ska
from . import asciiparser
from . import utils

def fromSkaMesh(lods: list, name: str, removeDoubles: bool, convMat: Matrix) -> list:
    """Create one Blender mesh object per SKA mesh LOD.

    Each LOD is built as a mesh object linked to the scene's master collection:
      * every surface gets a material slot (an existing material with the
        surface's name is reused) and contributes its triangles;
      * only the first 8 UV maps are imported, because Blender allows no more;
      * weight maps become vertex groups;
      * morph maps become shape keys on top of a "Basis" key.
    Finally ``convMat`` is applied to the mesh and its shape keys.

    Args:
        lods: mesh LODs as returned by the ``ska`` mesh readers.
        name: base object name; ``_LOD_<i>`` is appended when there is more
            than one LOD.
        removeDoubles: merge vertices by distance after the mesh is built.
        convMat: coordinate conversion matrix applied to the finished mesh.

    Returns:
        The created mesh objects, in LOD order.
    """
    createdObjects = []

    for i, lod in enumerate(lods):
        lodName = "{}_LOD_{}".format(name, i) if len(lods) > 1 else name

        mesh = bpy.data.meshes.new(lodName)
        mesh.from_pydata(lod.vertices, [], [])
        mesh.update()

        obj = bpy.data.objects.new(lodName, mesh)
        createdObjects.append(obj)

        bpy.context.scene.collection.objects.link(obj)
        # bpy.ops.object.mode_set works on the active object
        bpy.context.view_layer.objects.active = obj

        # create empty vertex groups; weights are assigned once the faces exist
        for wmap in lod.weightMaps:
            obj.vertex_groups.new(name=wmap.name)

        bpy.ops.object.mode_set(mode="EDIT")
        bm = bmesh.from_edit_mesh(mesh)
        bm.verts.ensure_lookup_table()

        # create all UV layers once, up front (Blender supports at most 8)
        uvLayers = [(uvmap, bm.loops.layers.uv.new(uvmap.name)) for uvmap in lod.uvmaps[:8]]

        # create faces, materials and uvs
        for j, surf in enumerate(lod.surfaces):
            mat = bpy.data.materials.get(surf.name, None)
            if (not mat):
                mat = bpy.data.materials.new(surf.name)
            mesh.materials.append(mat)

            # triangle indices are relative to the surface's first vertex
            first = surf.firstVtx
            for tri in surf.tris:
                try:
                    f = bm.faces.new((bm.verts[tri[0] + first],
                                      bm.verts[tri[1] + first],
                                      bm.verts[tri[2] + first]))
                except (ValueError, IndexError):
                    # duplicate/degenerate face or out-of-range vertex index: skip it
                    continue
                f.material_index = j

                for uvmap, layer in uvLayers:
                    for loop in f.loops:
                        # uv coords are stored per vertex; V is flipped (1 - v)
                        uv = uvmap.coords[loop.vert.index]
                        loop[layer].uv = (uv[0], 1.0 - uv[1])

        bmesh.update_edit_mesh(mesh)
        bpy.ops.object.mode_set(mode="OBJECT")

        # assign weights
        if (not lod.vertexWeights):
            # old layout: each weight map holds a {vertex index: weight} dict
            for wmap in lod.weightMaps:
                g = obj.vertex_groups[wmap.name]
                for vtx, val in wmap.values.items():
                    g.add([vtx], val, "ADD")
        else:
            # per-vertex layout: (surface-local weight map indices, weight bytes);
            # a byte of 255 is a full weight
            for surf in lod.surfaces:
                for j in range(surf.firstVtx, surf.firstVtx + surf.vtxCount):
                    for idx, wbyte in zip(*lod.vertexWeights[j]):
                        if (not wbyte):
                            continue

                        abs_idx = surf.weightMaps[idx]
                        g = obj.vertex_groups[lod.weightMaps[abs_idx].name]
                        g.add([j], wbyte / 255, "ADD")

        # create shape keys
        if (lod.morphMaps):
            if (not mesh.shape_keys):
                obj.shape_key_add(name="Basis")
                mesh.shape_keys.use_relative = lod.morphMaps[0].relative
            for mmap in lod.morphMaps:
                sk = obj.shape_key_add(name=mmap.name)
                for vtx, val in mmap.values.items():
                    sk.data[vtx].co = val[:3]

        if (removeDoubles):
            bpy.ops.object.mode_set(mode="EDIT")
            # Merge by Distance only acts on selected geometry
            bpy.ops.mesh.select_all(action="SELECT")
            bpy.ops.mesh.remove_doubles()
            bpy.ops.object.mode_set(mode="OBJECT")

        mesh.transform(convMat, shape_keys=True)

    return createdObjects

def fromSkaSkeleton(lods: list, name: str, convMat: Matrix) -> list:
    """Build one armature object per SKA skeleton LOD.

    Each armature is named ``<name>_ARM`` (or ``<name>_ARM_LOD_<i>`` when there
    are several LODs) and is linked to the scene's master collection. Every SKA
    bone becomes an edit bone placed by its absolute placement, conjugated by
    ``convMat``, and parented to the bone named by ``parentName``.

    Args:
        lods: skeleton LODs as returned by the ``ska`` skeleton readers.
        name: base name for the created armatures.
        convMat: coordinate conversion matrix.

    Returns:
        The created armature objects, in LOD order.
    """
    result = []

    for lodIndex, lod in enumerate(lods):
        armName = ("{}_ARM_LOD_{}".format(name, lodIndex) if len(lods) > 1
                   else "{}_ARM".format(name))

        armature = bpy.data.armatures.new(armName)
        armObj = bpy.data.objects.new(armName, armature)
        result.append(armObj)

        bpy.context.scene.collection.objects.link(armObj)
        bpy.context.view_layer.objects.active = armObj

        # edit bones can only be created while the armature is in edit mode
        bpy.ops.object.mode_set(mode="EDIT")

        for skaBone in lod.bones:
            editBone = armature.edit_bones.new(skaBone.name)

            # temporary non-zero extent; the real placement is assigned below
            editBone.head = (0, 0, 0)
            editBone.tail = (0, 1, 0)

            # absPlacement is 3 rows of a 3x4 matrix; complete it to 4x4
            rows = skaBone.absPlacement + [(0, 0, 0, 1)]
            placement = convMat @ Matrix(rows) @ convMat.inverted()

            if skaBone.parentName:
                editBone.parent = armature.edit_bones[skaBone.parentName]

            editBone.matrix = placement
            # keep a minimal length so the bone is not degenerate
            editBone.length = max(skaBone.length, 0.01)

        bpy.ops.object.mode_set(mode="OBJECT")

    return result

def fromSkaAnimset(anims: list, convMat: Matrix):
    """Import SKA animations as Blender actions.

    Bone envelopes are keyed onto the armature object that has the most pose
    bones. Morph envelopes go into a separate shape-key action (``id_root`` is
    "KEY"). That action is named ``SK_<animation>`` when the animation also has
    bone envelopes, and ``<animation>`` otherwise. Animations whose name
    already exists as an action are skipped.

    None of the created actions stays assigned, so each gets a fake user to
    keep it from being discarded when the .blend file is saved.

    Args:
        anims: animations as returned by ``ska.readAnimsetBin``.
        convMat: coordinate conversion matrix.
    """
    # pick the armature with the most pose bones
    armObj = None
    for obj in bpy.data.objects:
        if (obj.type == "ARMATURE"):
            if (armObj and len(armObj.pose.bones) > len(obj.pose.bones)):
                continue
            armObj = obj

    if (armObj):
        if (not armObj.animation_data):
            armObj.animation_data_create()

        bpy.context.view_layer.objects.active = armObj

        bpy.ops.object.mode_set(mode="POSE")

    bpy.context.scene.frame_set(0)

    try:
        for anim in anims:
            if (anim.name in bpy.data.actions):
                continue

            if (anim.morphEvps):
                # !!! in blender bone animations and shape keys (morph) animations are treated as different actions !!!
                # if animation has both bone and shape envelopes - create shape keys action as SK_<animation name>
                skActionName = "SK_{}".format(anim.name) if anim.boneEvps else anim.name
                skAction = bpy.data.actions.get(skActionName, None)
                if (not skAction):
                    skAction = bpy.data.actions.new(skActionName)
                    skAction.id_root = "KEY"
                    # never assigned here, so keep it alive on save
                    skAction.use_fake_user = True

                # animate shape keys: one keyframe per frame index
                for me in anim.morphEvps:
                    fc = skAction.fcurves.new('key_blocks["{}"].value'.format(me.mapName))
                    fc.keyframe_points.add(len(me.factors))

                    for i, f in enumerate(me.factors):
                        fc.keyframe_points[i].co = (i, f)

                    # points added via keyframe_points.add() have uninitialised handles
                    fc.update()

            if (not anim.boneEvps or not armObj):
                continue

            action = bpy.data.actions.new(anim.name)
            # unassigned at the end, so keep it alive on save
            action.use_fake_user = True
            armObj.animation_data.action = action

            convInv = convMat.inverted()

            for be in anim.boneEvps:
                bone = armObj.pose.bones.get(be.boneName, None)
                if (not bone):
                    continue

                # envelope values are taken into armature space through the
                # parent's rest matrix (when there is a parent) and then into
                # this bone's rest space
                invRest = bone.bone.matrix_local.inverted()
                parentRest = bone.parent.bone.matrix_local if bone.parent else None

                for frm, pos in be.positions.items():
                    vec = convMat @ Vector(pos)
                    if (parentRest is not None):
                        vec = parentRest @ vec

                    bone.location = invRest @ vec
                    bone.keyframe_insert("location", frame=frm)

                for frm, rot in be.rotations.items():
                    mat = convMat @ Quaternion(rot).to_matrix().to_4x4() @ convInv
                    if (parentRest is not None):
                        mat = parentRest @ mat

                    bone.rotation_quaternion = (invRest @ mat).to_quaternion()
                    bone.keyframe_insert("rotation_quaternion", frame=frm)
    finally:
        # restore the armature even if an animation failed to import
        if (armObj):
            bpy.ops.pose.transforms_clear()
            bpy.ops.object.mode_set(mode="OBJECT")
            armObj.animation_data.action = None

def fromSkaAsciiAnimation(filename: str, convMat: Matrix):
    """Import an ascii SKA animation file as Blender actions.

    Bone envelopes are keyed onto the armature object with the most pose bones
    as an action named after ``ANIM_ID``. Morph envelopes become a shape-key
    action (``id_root`` "KEY") named ``SK_<ANIM_ID>`` when a bone action was
    also created, and ``<ANIM_ID>`` otherwise. Both actions get a fake user
    because neither stays assigned.

    Args:
        filename: path of the ascii animation file.
        convMat: coordinate conversion matrix.

    Raises:
        ValueError: if the ``SE_ANIM`` version is not 0.1.
    """
    # read and validate the header before touching the scene, so an invalid
    # file does not leave the armature in pose mode
    with open(filename, "r") as file:
        parser = asciiparser.AsciiParser(StringIO(file.read()))

    ver = parser.getFloat("SE_ANIM")
    if (ver != 0.1):
        raise ValueError("Invalid ascii animation file version {}".format(ver))

    parser.getFloat("SEC_PER_FRAME")  # not used yet, but must be consumed
    frameCount = parser.getInt("FRAMES")
    name = parser.getString("ANIM_ID")

    # pick the armature with the most pose bones
    armObj = None
    for obj in bpy.data.objects:
        if (obj.type == "ARMATURE"):
            if (armObj and len(armObj.pose.bones) > len(obj.pose.bones)):
                continue
            armObj = obj

    if (armObj):
        if (not armObj.animation_data):
            armObj.animation_data_create()

        bpy.context.view_layer.objects.active = armObj

        bpy.ops.object.mode_set(mode="POSE")

    bpy.context.scene.frame_set(0)

    convInv = convMat.inverted()

    def toMatrix(values):
        # 12 floats (3 rows of 4) -> converted 4x4 matrix
        return convMat @ Matrix([values[:4], values[4:8], values[8:], (0, 0, 0, 1)]) @ convInv

    try:
        action = None

        for _ in parser.iterList("BONEENVELOPES"):
            if (action is None):
                action = bpy.data.actions.new(name)
                # unassigned at the end, so keep it alive on save
                action.use_fake_user = True

                if (armObj):
                    armObj.animation_data.action = action

            boneName = parser.getString("NAME")

            parser.expect("DEFAULT_POSE")
            parser.beginBlock()
            defPos = toMatrix(parser.scanf("%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f;"))
            parser.endBlock()

            # with no armature, or an armature lacking this bone, the frames are
            # still read to keep the parser in sync
            bone = armObj.pose.bones.get(boneName, None) if armObj else None
            invDef = defPos.inverted() if bone else None

            parser.beginBlock()
            for j in range(frameCount):
                values = parser.scanf("%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f;")
                if (not bone):
                    continue

                # frame transform relative to the default pose
                rel = invDef @ toMatrix(values)

                bone.location = rel.translation
                bone.rotation_quaternion = rel.to_quaternion()

                bone.keyframe_insert("location", frame=j)
                bone.keyframe_insert("rotation_quaternion", frame=j)
            parser.endBlock()

        skAction = None

        for _ in parser.iterList("MORPHENVELOPES"):
            if (skAction is None):
                skAction = bpy.data.actions.new("SK_{}".format(name) if action else name)
                skAction.id_root = "KEY"
                # never assigned here, so keep it alive on save
                skAction.use_fake_user = True

            mapName = parser.getString("NAME")

            fc = skAction.fcurves.new('key_blocks["{}"].value'.format(mapName))
            fc.keyframe_points.add(frameCount)

            parser.beginBlock()
            for j in range(frameCount):
                val = parser.scanf("%f;")[0]
                fc.keyframe_points[j].co = (j, val)
            parser.endBlock()

            # points added via keyframe_points.add() have uninitialised handles
            fc.update()

        parser.expect("SE_ANIM_END")
    finally:
        # restore the armature even if parsing failed half-way
        if (armObj):
            bpy.ops.pose.transforms_clear()
            bpy.ops.object.mode_set(mode="OBJECT")
            armObj.animation_data.action = None

def loadMeshBin(filename: str, removeDoubles: bool, convMat: Matrix) -> list:
    """Import a binary SKA mesh file; return the created mesh objects.

    The objects are named after the file name without its extension.
    """
    meshLods = ska.readMeshBin(filename)
    baseName, _ext = os.path.splitext(os.path.basename(filename))
    return fromSkaMesh(meshLods, baseName, removeDoubles, convMat)

def loadMeshAscii(filename: str, removeDoubles: bool, convMat: Matrix) -> list:
    """Import an ascii SKA mesh file (a single LOD); return the created objects.

    The object is named after the file name without its extension.
    """
    meshLod = ska.readMeshAscii(filename)
    baseName, _ext = os.path.splitext(os.path.basename(filename))
    return fromSkaMesh([meshLod], baseName, removeDoubles, convMat)

def loadSkeletonBin(filename: str, convMat: Matrix) -> list:
    """Import a binary SKA skeleton file; return the created armature objects.

    The armatures are named after the file name without its extension.
    """
    skeletonLods = ska.readSkeletonBin(filename)
    baseName, _ext = os.path.splitext(os.path.basename(filename))
    return fromSkaSkeleton(skeletonLods, baseName, convMat)

def loadSkeletonAscii(filename: str, convMat: Matrix) -> list:
    """Import an ascii SKA skeleton file (a single LOD); return the armature objects.

    The armature is named after the file name without its extension.
    """
    skeletonLod = ska.readSkeletonAscii(filename)
    baseName, _ext = os.path.splitext(os.path.basename(filename))
    return fromSkaSkeleton([skeletonLod], baseName, convMat)

def loadAnimsetBin(filename: str, convMat: Matrix):
    """Import every animation of a binary SKA animation set file as actions."""
    animations = ska.readAnimsetBin(filename)
    fromSkaAnimset(animations, convMat)

def loadAnimsetAscii(filename: str, convMat: Matrix):
    """Import an ascii SKA animation file as actions (thin wrapper)."""
    # the ascii format is parsed directly into Blender data; no ska reader is involved
    fromSkaAsciiAnimation(filename, convMat)

def loadModelAscii(filename: str, loadMesh: bool, loadSkeleton: bool, loadAnimset: bool, applyArmature: bool, removeDoubles: bool, convMat: Matrix):
    """Import an ascii model configuration file and the files it references.

    Lines beginning with ``MESH``, ``SKELETON`` or ``ANIMSET`` are expected to
    look like ``<KEYWORD> TFNM "<path>";``. Keyword lines without ``TFNM`` are
    ignored. Paths are resolved against the game root when one is found, and
    otherwise relative to ``filename``. Unresolvable files are reported with
    ``print`` and skipped.

    Args:
        filename: path of the model configuration file.
        loadMesh / loadSkeleton / loadAnimset: which referenced kinds to import.
        applyArmature: parent the i-th mesh object to the i-th armature object
            with an armature modifier.
        removeDoubles: forwarded to the mesh importer.
        convMat: coordinate conversion matrix.
    """
    # need game root since filenames are relative
    root = utils.getGameRoot(filename)

    def extractPath(text):
        # text after TFNM without surrounding whitespace, ';' and quotes;
        # None if the line has no TFNM
        _, sep, rest = text.partition("TFNM")
        if (not sep):
            return None
        return rest.strip().rstrip(";").strip().strip('"')

    def resolve(fn, what):
        # absolute path of a referenced file, or None (after reporting) if not found
        if (root):
            return os.path.join(root, fn)
        fp = utils.resolveRelativePath(filename, fn)
        if (not fp):
            print("Could not find {} at '{}'".format(what, fn))
        return fp

    # keyword -> referenced paths (insertion order gives MESH/SKELETON/ANIMSET priority)
    refs = {"MESH": [], "SKELETON": [], "ANIMSET": []}
    with open(filename, "r") as file:
        for line in file:
            text = line.strip()
            for keyword, paths in refs.items():
                if (text.startswith(keyword)):
                    path = extractPath(text)
                    if (path):
                        paths.append(path)
                    break

    objects = []
    armObjects = []
    if (loadMesh):
        for fn in refs["MESH"]:
            fp = resolve(fn, "mesh")
            if (fp):
                objects += loadMeshBin(fp, removeDoubles, convMat)
    if (loadSkeleton):
        for fn in refs["SKELETON"]:
            fp = resolve(fn, "skeleton")
            if (fp):
                armObjects += loadSkeletonBin(fp, convMat)
    if (loadAnimset):
        for fn in refs["ANIMSET"]:
            fp = resolve(fn, "animation set")
            if (fp):
                loadAnimsetBin(fp, convMat)

    # apply armatures if needed: mesh LOD i is paired with armature LOD i
    if (applyArmature and objects and armObjects):
        for obj, armObj in zip(objects, armObjects):
            bpy.ops.object.select_all(action="DESELECT")
            obj.select_set(True)
            armObj.select_set(True)
            bpy.context.view_layer.objects.active = armObj
            bpy.ops.object.parent_set(type="ARMATURE")
        bpy.ops.object.select_all(action="DESELECT")

def loadModelBin(filename: str, loadMesh: bool, loadSkeleton: bool, loadAnimset: bool, applyArmature: bool, removeDoubles: bool, convMat: Matrix):
    """Import a binary model configuration file and the files it references.

    Referenced paths come from ``ska.readModelConfigurationBin``. They are
    resolved against the game root when one is found, and otherwise relative
    to ``filename``. Unresolvable files are reported with ``print`` and
    skipped. Meshes are imported first, then skeletons, then animation sets.

    Args:
        filename: path of the binary model configuration file.
        loadMesh / loadSkeleton / loadAnimset: which referenced kinds to import.
        applyArmature: parent the i-th mesh object to the i-th armature object
            with an armature modifier.
        removeDoubles: forwarded to the mesh importer.
        convMat: coordinate conversion matrix.
    """
    meshPaths, skeletonPaths, animsetPaths = ska.readModelConfigurationBin(filename)

    gameRoot = utils.getGameRoot(filename)

    def locate(relPath, label):
        # full path of a referenced file, or a falsy value if it cannot be found
        if not gameRoot:
            found = utils.resolveRelativePath(filename, relPath)
            if not found:
                print("Could not find {} at '{}'".format(label, relPath))
            return found
        return os.path.join(gameRoot, relPath)

    meshObjects = []
    armatureObjects = []

    if loadMesh and meshPaths:
        for relPath in meshPaths:
            fullPath = locate(relPath, "mesh")
            if fullPath:
                meshObjects += loadMeshBin(fullPath, removeDoubles, convMat)

    if loadSkeleton and skeletonPaths:
        for relPath in skeletonPaths:
            fullPath = locate(relPath, "skeleton")
            if fullPath:
                armatureObjects += loadSkeletonBin(fullPath, convMat)

    if loadAnimset and animsetPaths:
        for relPath in animsetPaths:
            fullPath = locate(relPath, "animation set")
            if fullPath:
                loadAnimsetBin(fullPath, convMat)

    # pair the i-th mesh object with the i-th armature object
    if not (applyArmature and meshObjects and armatureObjects):
        return
    for meshObj, armObj in zip(meshObjects, armatureObjects):
        bpy.ops.object.select_all(action="DESELECT")
        meshObj.select_set(True)
        armObj.select_set(True)
        bpy.context.view_layer.objects.active = armObj
        bpy.ops.object.parent_set(type="ARMATURE")
    bpy.ops.object.select_all(action="DESELECT")

def load(files: list, loadMesh: bool, loadSkeleton: bool, loadAnimset: bool, applyArmature: bool, removeDoubles: bool, convMat: Matrix):
    """Import each file in ``files``, choosing the importer by file extension.

    The extension match is case-insensitive; files with unknown extensions are
    ignored. The ``loadMesh``/``loadSkeleton``/``loadAnimset`` flags only
    affect model configuration files (.bmc/.smc), which also parent their own
    meshes. For directly listed mesh and skeleton files, mesh object i is
    parented to armature object i when ``applyArmature`` is set.
    """
    meshLoaders = {".bm": loadMeshBin, ".am": loadMeshAscii}
    skeletonLoaders = {".bs": loadSkeletonBin, ".as": loadSkeletonAscii}
    animLoaders = {".ba": loadAnimsetBin, ".aa": loadAnimsetAscii}
    modelLoaders = {".bmc": loadModelBin, ".smc": loadModelAscii}

    meshObjects = []
    armatureObjects = []

    for path in files:
        extension = os.path.splitext(path)[1].casefold()

        if extension in meshLoaders:
            meshObjects += meshLoaders[extension](path, removeDoubles, convMat)
        elif extension in skeletonLoaders:
            armatureObjects += skeletonLoaders[extension](path, convMat)
        elif extension in animLoaders:
            animLoaders[extension](path, convMat)
        elif extension in modelLoaders:
            modelLoaders[extension](path, loadMesh, loadSkeleton, loadAnimset, applyArmature, removeDoubles, convMat)

    # pair the i-th mesh object with the i-th armature object
    if not (applyArmature and meshObjects and armatureObjects):
        return
    for meshObj, armObj in zip(meshObjects, armatureObjects):
        bpy.ops.object.select_all(action="DESELECT")
        meshObj.select_set(True)
        armObj.select_set(True)
        bpy.context.view_layer.objects.active = armObj
        bpy.ops.object.parent_set(type="ARMATURE")
    bpy.ops.object.select_all(action="DESELECT")
        