import copy
import os
import sys
from pathlib import Path
# Add the root directory of the project to the sys.path to allow imports
current_file_dir = os.path.dirname(os.path.abspath(__file__))
rootpath = os.path.abspath(os.path.join(current_file_dir, '../'))
sys.path.append(rootpath)

# Now you can import the module
from nfn_moe.common.weight_space import MoEWeightSpaceFeatures

import random

import numpy as np
import torch
from model_vision import VisionTransformer
from torch import nn

from nfn_moe.common.weight_space import (
    MoEWeightSpaceFeatures,
    moe_network_spec_from_wsfeat,
)
from nfn_moe.layers.layers import MoELinearEquiv, MoELinearInv
from nfn_moe.layers.misc_layers import TupleOpMoE


from check_group_action import sample_group_action

def set_seed(manualSeed=3):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        manualSeed: Integer seed handed to every generator; its string form
            is also exported as the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(manualSeed)

    # Every seeding entry point takes the same single integer argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(manualSeed)

    # Disable cuDNN autotuning and force its deterministic code paths.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True



def _permute_expert_hidden(feature, S_G, Pi_e, hidden_axis):
    """Reorder axis 2 of ``feature`` by ``S_G`` and permute each slice's hidden axis.

    After the reordering, slice ``i`` along axis 2 is indexed on its own axis
    ``hidden_axis`` with row ``i`` of ``Pi_e[S_G]``; the slices are then
    stacked back along axis 2.
    """
    # Both fancy-index results do not depend on ``i``; build them once
    # instead of once per slice.
    reordered = feature[:, :, S_G]
    hidden_perms = Pi_e[S_G]
    leading = (slice(None),) * hidden_axis
    return torch.stack(
        [
            reordered[:, :, i][leading + (hidden_perms[i],)]
            for i in range(feature.shape[2])
        ],
        dim=2,
    )


def apply_group_action_to_wsfeat(wsfeat, group_actions):
    """Apply one group action per layer to a set of weight-space features.

    Args:
        wsfeat: Object holding ``W_q``, ``W_k``, ``W_v``, ``W_o``, ``W_G``,
            ``W_A``, ``W_B``, ``b_G``, ``b_A`` and ``b_B`` as instance
            attributes. Each is indexed by layer first; the per-layer tensor
            is then indexed on axis 2 by ``S_h`` (attention features) or
            ``S_G`` (gate/expert features).
            Assumes the per-layer layout is (batch, channel, heads-or-experts,
            ...) as built by the tests in this file -- TODO confirm for other
            callers.
        group_actions: One dict per layer with the keys ``S_h``, ``M_k``,
            ``M_v``, ``S_G``, ``gamma_W``, ``gamma_b`` and ``Pi_e``.

    Returns:
        A ``MoEWeightSpaceFeatures`` built from the transformed features,
        each stacked over layers along a new leading axis.

    NOTE(review): a feature missing from ``wsfeat.__dict__`` is skipped in the
    loop, but its list stays empty and ``torch.stack`` rejects an empty list,
    so in practice all ten features must be present.
    """
    present = wsfeat.__dict__
    g_wsfeat = {
        key: []
        for key in ("W_q", "W_k", "W_v", "W_o", "W_G", "W_A", "W_B", "b_G", "b_A", "b_B")
    }

    for layer_idx, group_action in enumerate(group_actions):
        S_h = group_action["S_h"]
        M_k = group_action["M_k"]
        M_v = group_action["M_v"]
        S_G = group_action["S_G"]
        gamma_W = group_action["gamma_W"]
        gamma_b = group_action["gamma_b"]
        Pi_e = group_action["Pi_e"]

        # Attention weights: reorder axis 2 by S_h, then multiply by the
        # (equally reordered) M_k / M_v. W_q gets M_k^T and W_k gets M_k^{-1},
        # so the product W_q @ W_k^T is unchanged; likewise W_v gets M_v and
        # W_o gets M_v^{-1}.
        if "W_q" in present:
            g_wsfeat["W_q"].append(
                wsfeat.W_q[layer_idx][:, :, S_h] @ M_k[S_h].transpose(-1, -2)
            )

        if "W_k" in present:
            g_wsfeat["W_k"].append(
                wsfeat.W_k[layer_idx][:, :, S_h] @ torch.inverse(M_k[S_h])
            )

        if "W_v" in present:
            g_wsfeat["W_v"].append(wsfeat.W_v[layer_idx][:, :, S_h] @ M_v[S_h])

        if "W_o" in present:
            g_wsfeat["W_o"].append(
                torch.inverse(M_v[S_h]) @ wsfeat.W_o[layer_idx][:, :, S_h]
            )

        # Gate parameters: reorder axis 2 by S_G and add the gamma offsets.
        if "W_G" in present:
            g_wsfeat["W_G"].append(wsfeat.W_G[layer_idx][:, :, S_G] + gamma_W)

        if "b_G" in present:
            g_wsfeat["b_G"].append(wsfeat.b_G[layer_idx][:, :, S_G] + gamma_b)

        # Expert parameters: reorder axis 2 by S_G, then permute the hidden
        # axis of every slice (last axis of W_A and b_A, first remaining axis
        # of W_B).
        if "W_A" in present:
            g_wsfeat["W_A"].append(
                _permute_expert_hidden(wsfeat.W_A[layer_idx], S_G, Pi_e, 3)
            )

        if "b_A" in present:
            g_wsfeat["b_A"].append(
                _permute_expert_hidden(wsfeat.b_A[layer_idx], S_G, Pi_e, 2)
            )

        if "W_B" in present:
            g_wsfeat["W_B"].append(
                _permute_expert_hidden(wsfeat.W_B[layer_idx], S_G, Pi_e, 2)
            )

        # b_B is only reordered along axis 2.
        if "b_B" in present:
            g_wsfeat["b_B"].append(wsfeat.b_B[layer_idx][:, :, S_G])

    # Stack the per-layer results: [n_layers, ...]
    stacked = {key: torch.stack(per_layer, dim=0) for key, per_layer in g_wsfeat.items()}
    return MoEWeightSpaceFeatures(**stacked)


def check_params_eq(params1: MoEWeightSpaceFeatures, params2: MoEWeightSpaceFeatures):
    weight_keys = ["W_q", "W_k", "W_v", "W_o", "W_G", "W_A", "W_B"]
    bias_keys = ["b_G", "b_A", "b_B"]

    # Compare weights
    equal = True
    for key in weight_keys:
        weight1 = getattr(params1, key)
        weight2 = getattr(params2, key)

        for w1, w2 in zip(weight1, weight2):
            if not torch.allclose(w1, w2, atol=1e-2, rtol=1e-2, equal_nan=True):
                print(f"Mismatch found in {key}, diff {torch.abs(w1-w2).max()}")
                equal = False

    # Compare biases
    for key in bias_keys:
        bias1 = getattr(params1, key)
        bias2 = getattr(params2, key)

        for b1, b2 in zip(bias1, bias2):
            if not torch.allclose(b1, b2, atol=1e-2, rtol=1e-2, equal_nan=True):
                print(f"Mismatch found in {key}, diff {torch.abs(w1-w2).max()}")
                equal = False

    return equal


def test_layer_equivariance_group_action():
    """Numerically check that the equivariant stack commutes with the group action.

    Random float64 features are created on the CUDA device and pushed through
    ``MoELinearEquiv`` followed by a ReLU wrapped in ``TupleOpMoE`` (with the
    attention and gate features listed in ``masked_features``). For 20 freshly
    sampled group actions, the result of "act, then apply the network" is
    compared with "apply the network, then act" via ``check_params_eq``.

    NOTE(review): the comparison result is only printed, never asserted, so
    this function cannot fail under a test runner on a mismatch.
    """
    # Model dimensions
    embed_dim = 32
    n_layers = 4
    n_heads = 4
    n_experts = 4
    forward_mul = 2
    bsz = 1  # The batch size (number of models)
    D_k = embed_dim // n_heads
    D_v = embed_dim // n_heads
    D_A = embed_dim * forward_mul
    D = embed_dim

    channel_in = 1
    channel_out = 1

    # Trailing dimensions of every feature; each tensor is prefixed with
    # [n_layers, bsz, channel_in]. Dict order fixes the order of the random draws.
    trailing_shapes = {
        "W_q": (n_heads, embed_dim, D_k),      # [h, D, D_q]
        "W_k": (n_heads, embed_dim, D_k),      # [h, D, D_k]
        "W_v": (n_heads, embed_dim, D_v),      # [h, D, D_v]
        "W_o": (n_heads, D_v, embed_dim),      # [h, D_v, D]
        "W_G": (n_experts, embed_dim),         # [n_experts, D]
        "W_A": (n_experts, embed_dim, D_A),    # [n_experts, D, D_A]
        "W_B": (n_experts, D_A, embed_dim),    # [n_experts, D_A, D]
        "b_G": (n_experts,),                   # [n_experts]
        "b_A": (n_experts, D_A),               # [n_experts, D_A]
        "b_B": (n_experts, embed_dim),         # [n_experts, D]
    }
    ws_dict = {
        key: torch.randn(n_layers, bsz, channel_in, *shape, dtype=torch.float64, device="cuda")
        for key, shape in trailing_shapes.items()
    }

    # Display the shapes to verify the transformation
    for key, tensor in ws_dict.items():
        print(f"{key} shape: {tensor.shape}")

    wsfeat = MoEWeightSpaceFeatures(**ws_dict)

    encoder_weight_spec = moe_network_spec_from_wsfeat(wsfeat)

    actual_D, actual_D_q, actual_D_k, actual_D_v, _n_e, actual_D_A, _ = encoder_weight_spec.get_all_dims()

    # Compare actual dimensions to expected dimensions
    assert actual_D == D, f"Expected D={D}, but got {actual_D}"
    assert actual_D_q == D_k, f"Expected D_q={D_k}, but got {actual_D_q}"
    assert actual_D_k == D_k, f"Expected D_k={D_k}, but got {actual_D_k}"
    assert actual_D_v == D_v, f"Expected D_v={D_v}, but got {actual_D_v}"
    assert actual_D_A == D_A, f"Expected D_A={D_A}, but got {actual_D_A}"

    nfn = nn.Sequential(
        MoELinearEquiv(encoder_weight_spec, channel_in, channel_out),
        TupleOpMoE(nn.ReLU(), masked_features=['W_q', 'W_k', 'W_v', 'W_o', 'W_G', 'b_G']),
    ).cuda().double()
    out = nfn(wsfeat)

    for _ in range(20):
        group_actions = [sample_group_action(n_heads, n_experts, D_k, D_v, D_A, D) for _ in range(n_layers)]

        # Eg: act on the input, then apply the network.
        # gE: apply the network, then act on the output.
        wsfeat1 = apply_group_action_to_wsfeat(wsfeat, group_actions)
        gE = apply_group_action_to_wsfeat(out, group_actions)
        Eg = nfn(wsfeat1)
        print(check_params_eq(Eg, gE))


def test_layer_invariant_group_action():
    """Numerically check that ``MoELinearInv`` is unchanged by the group action.

    Random float64 features with two channels are created on the CUDA device
    and passed through ``MoELinearInv``. One group action per layer is sampled
    and applied to the input; the network output before and after is printed
    and compared with ``torch.allclose`` (``atol`` and ``rtol`` of 1e-3).

    NOTE(review): the comparison result is only printed, never asserted, so
    this function cannot fail under a test runner on a mismatch.
    """
    # Model dimensions
    embed_dim = 32
    n_layers = 4
    n_heads = 2
    n_experts = 4
    forward_mul = 2
    bsz = 1  # The batch size (number of models)
    D_k = embed_dim // n_heads
    D_v = embed_dim // n_heads
    D_A = embed_dim * forward_mul
    D = embed_dim

    channel_in = 2
    channel_out = 2

    # Trailing dimensions of every feature; each tensor is prefixed with
    # [n_layers, bsz, channel_in]. Dict order fixes the order of the random draws.
    trailing_shapes = {
        "W_q": (n_heads, embed_dim, D_k),      # [h, D, D_q]
        "W_k": (n_heads, embed_dim, D_k),      # [h, D, D_k]
        "W_v": (n_heads, embed_dim, D_v),      # [h, D, D_v]
        "W_o": (n_heads, D_v, embed_dim),      # [h, D_v, D]
        "W_G": (n_experts, embed_dim),         # [n_experts, D]
        "W_A": (n_experts, embed_dim, D_A),    # [n_experts, D, D_A]
        "W_B": (n_experts, D_A, embed_dim),    # [n_experts, D_A, D]
        "b_G": (n_experts,),                   # [n_experts]
        "b_A": (n_experts, D_A),               # [n_experts, D_A]
        "b_B": (n_experts, embed_dim),         # [n_experts, D]
    }
    ws_dict = {
        key: torch.randn(n_layers, bsz, channel_in, *shape, dtype=torch.float64, device="cuda")
        for key, shape in trailing_shapes.items()
    }

    # Display the shapes to verify the transformation
    for key, tensor in ws_dict.items():
        print(f"{key} shape: {tensor.shape}")

    wsfeat = MoEWeightSpaceFeatures(**ws_dict)

    encoder_weight_spec = moe_network_spec_from_wsfeat(wsfeat)

    actual_D, actual_D_q, actual_D_k, actual_D_v, _n_e, actual_D_A, _ = encoder_weight_spec.get_all_dims()

    # Compare actual dimensions to expected dimensions
    assert actual_D == D, f"Expected D={D}, but got {actual_D}"
    assert actual_D_q == D_k, f"Expected D_q={D_k}, but got {actual_D_q}"
    assert actual_D_k == D_k, f"Expected D_k={D_k}, but got {actual_D_k}"
    assert actual_D_v == D_v, f"Expected D_v={D_v}, but got {actual_D_v}"
    assert actual_D_A == D_A, f"Expected D_A={D_A}, but got {actual_D_A}"

    nfn = nn.Sequential(
        MoELinearInv(encoder_weight_spec, channel_in, channel_out),
    ).cuda().double()
    out = nfn(wsfeat)

    for _ in range(1):
        group_actions = [sample_group_action(n_heads, n_experts, D_k, D_v, D_A, D) for _ in range(n_layers)]

        # The output for the transformed input should match the original output.
        wsfeat1 = apply_group_action_to_wsfeat(wsfeat, group_actions)
        out1 = nfn(wsfeat1)
        print(out)
        print(out1)
        inv = torch.allclose(out, out1, atol=1e-3, rtol=1e-3, equal_nan=True)
        print(inv)


if __name__ == "__main__":
    # Script entry point: only the equivariance check runs by default.
    # Seeding and the invariance check are disabled; uncomment to enable them.
    # set_seed(6)
    test_layer_equivariance_group_action()
    # test_layer_invariant_group_action()
