from model_vision import VisionTransformer
import copy
import torch
import random
import numpy as np
import os

def set_seed(manualSeed=3):
    """Seed every RNG this script relies on and make cuDNN deterministic.

    Args:
        manualSeed: integer seed applied to Python's ``random``, NumPy and
            PyTorch (CPU generator plus the current and all CUDA devices).
    """
    seed = manualSeed
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Prefer reproducible cuDNN kernels over autotuned (possibly faster) ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

    # Python reads PYTHONHASHSEED at interpreter start-up, so this only
    # influences child processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)

class Checkpoint_WeightSpace_Converter():
    """Translate encoder parameters between a ``state_dict`` and "weight space".

    Weight space is a list with one dict per encoder layer, keyed by the short
    symbols in ``statedict_to_weightspace_names`` (``W_q``, ``b_G``, ...).
    Only state-dict keys of the form ``encoder.<layer>. ... .<module>.<param>``
    with ``param`` in {weight, bias} and a recognised ``module`` are touched.
    """

    def __init__(self) -> None:
        # state_dict module name -> {parameter suffix -> weight-space symbol}
        attention = {
            "queries": {"weight": "W_q"},
            "keys": {"weight": "W_k"},
            "values": {"weight": "W_v"},
            "out_projection": {"weight": "W_o"},
        }
        mixture_of_experts = {
            "gate": {"weight": "W_G", "bias": "b_G"},
            "htoh4": {"weight": "W_A", "bias": "b_A"},
            "h4toh": {"weight": "W_B", "bias": "b_B"},
        }
        self.statedict_to_weightspace_names = {**attention, **mixture_of_experts}

    def convert_state_dict_to_weight_space(self, checkpoint, n_encoder_layer, n_heads):
        """Extract encoder parameters from ``checkpoint`` into weight space.

        Args:
            checkpoint: mapping of dotted parameter names to tensors. The second
                dotted component of an ``encoder.*`` key is parsed as the layer index.
            n_encoder_layer: number of per-layer dicts to allocate.
            n_heads: number of attention heads used to split projection weights.

        Returns:
            List of ``n_encoder_layer`` dicts of float64 tensors:
            ``W_q``/``W_k``/``W_v`` are the (out, in) weight reshaped to
            (n_heads, out // n_heads, in) and then transposed to
            (n_heads, in, out // n_heads); ``W_o`` is the (out, in) weight
            transposed and reshaped to (n_heads, in // n_heads, out);
            ``W_A``/``W_B`` have their last two axes swapped; ``W_G`` and all
            biases keep their layout. A parameter of a recognised module that
            has no entry in the name map (e.g. a bias on ``queries``) is not
            supported and raises.
        """
        names = self.statedict_to_weightspace_names
        layers = [{} for _ in range(n_encoder_layer)]

        for full_name in list(checkpoint.keys()):
            parts = full_name.split(".")
            if parts[0] != "encoder" or parts[-1] not in ("weight", "bias"):
                continue

            module, param = parts[-2], parts[-1]
            tensor = checkpoint[full_name]
            shape = tensor.shape

            if module in ("queries", "keys", "values"):
                value = tensor.reshape(n_heads, shape[0] // n_heads, shape[1]).transpose(-1, -2)
            elif module == "out_projection":
                value = tensor.transpose(0, 1).reshape(n_heads, shape[1] // n_heads, shape[0])
            elif module in ("gate", "htoh4", "h4toh"):
                # Expert MLP weights are stored with swapped last axes; the gate
                # weight and every bias are copied as-is.
                swap_axes = param == "weight" and module != "gate"
                value = tensor.transpose(-1, -2) if swap_axes else tensor
            else:
                continue

            layers[int(parts[1])][names[module][param]] = value.to(torch.float64)  # Use float64 precision

        return layers

    def convert_weight_space_to_state_dict(self, encoder, checkpoint):
        """Write weight-space tensors back into ``checkpoint`` (in place).

        This is the inverse of ``convert_state_dict_to_weight_space``. The
        target shape of every converted entry is taken from the tensor
        currently stored in ``checkpoint``.

        Args:
            encoder: list of per-layer weight-space dicts.
            checkpoint: mapping of dotted parameter names to tensors; recognised
                encoder entries are overwritten with float64 tensors.

        Returns:
            The same ``checkpoint`` object, updated.
        """
        names = self.statedict_to_weightspace_names
        recognised = ("queries", "keys", "values", "out_projection", "gate", "htoh4", "h4toh")

        for full_name in list(checkpoint.keys()):
            parts = full_name.split(".")
            if parts[0] != "encoder" or parts[-1] not in ("weight", "bias"):
                continue

            module, param = parts[-2], parts[-1]
            target_shape = checkpoint[full_name].shape
            if module not in recognised:
                continue

            source = encoder[int(parts[1])][names[module][param]]
            if module in ("queries", "keys", "values"):
                restored = source.transpose(-1, -2).reshape(target_shape[0], target_shape[1])
            elif module == "out_projection":
                restored = source.reshape(target_shape[1], target_shape[0]).transpose(-1, -2)
            elif param == "weight" and module != "gate":
                restored = source.transpose(-1, -2)
            else:
                restored = source

            checkpoint[full_name] = restored.to(torch.float64)

        return checkpoint

def sample_group_action(n_heads, n_experts, D_k, D_v, D_A, D, device="cuda", dtype=torch.float64):
    """Draw one random element of the weight-space symmetry group for a layer.

    Args:
        n_heads: number of attention heads.
        n_experts: number of experts.
        D_k: per-head query/key width.
        D_v: per-head value width.
        D_A: hidden width of each expert MLP.
        D: model (embedding) width.
        device: device for ``Pi_e`` and the real-valued factors. Defaults to
            ``"cuda"`` (the previous hard-coded value); pass ``"cpu"`` to
            sample without a GPU.
        dtype: floating dtype of ``M_k``, ``M_v``, ``gamma_W`` and ``gamma_b``
            (default float64).

    Returns:
        dict with
          ``S_h``: permutation of range(n_heads), shape (n_heads,), on CPU;
          ``S_G``: permutation of range(n_experts), shape (n_experts,), on CPU;
          ``Pi_e``: one permutation of range(D_A) per expert, shape (n_experts, D_A);
          ``M_k``: (n_heads, D_k, D_k) matrices with entries drawn from [0, 1);
          ``M_v``: (n_heads, D_v, D_v) matrices with entries drawn from [0, 1);
          ``gamma_W``: (D,) vector with entries in [0, 1);
          ``gamma_b``: (1,) tensor with entry in [0, 1).
    """
    # Head/expert permutations are drawn on the CPU generator, keeping the
    # seeded random stream identical regardless of ``device``.
    S_h = torch.randperm(n_heads)
    S_G = torch.randperm(n_experts)
    Pi_e = torch.stack([torch.randperm(D_A, device=device) for _ in range(n_experts)], dim=0)

    # M_k / M_v are later inverted by apply_group_action; dense uniform
    # matrices are invertible with probability one.
    M_k = torch.rand(n_heads, D_k, D_k, device=device, dtype=dtype)
    M_v = torch.rand(n_heads, D_v, D_v, device=device, dtype=dtype)
    gamma_W = torch.rand(D, device=device, dtype=dtype)
    gamma_b = torch.rand(1, device=device, dtype=dtype)

    return {"S_h": S_h, "M_k": M_k, "M_v": M_v,
            "S_G": S_G, "gamma_W": gamma_W, "gamma_b": gamma_b,
            "Pi_e": Pi_e}

def apply_group_action(encoders, group_actions):
    """Apply one sampled group element to each encoder layer in weight space.

    Args:
        encoders: list of per-layer weight-space dicts (``W_q``, ``W_k``, ...).
        group_actions: list of dicts with keys ``S_h``, ``M_k``, ``M_v``,
            ``S_G``, ``gamma_W``, ``gamma_b`` and ``Pi_e``; zipped with
            ``encoders`` layer by layer (surplus entries on either side are ignored).

    Returns:
        A new list of per-layer dicts. Transformed tensors are freshly
        allocated, so the inputs are left unmodified. Keys without a rule are
        passed through by reference.

    Per layer:
      * heads are reordered by ``S_h`` and, within each head,
        ``W_q -> W_q M_k^T`` and ``W_k -> W_k M_k^{-1}``, which leaves
        ``W_q W_k^T`` unchanged, and ``W_v -> W_v M_v`` and
        ``W_o -> M_v^{-1} W_o``, which leaves ``W_v W_o`` unchanged;
      * experts are reordered by ``S_G``. ``gamma_W`` is added to every gate
        row and ``gamma_b`` to every gate bias, i.e. the same shift for all experts;
      * the hidden units of new expert ``i`` are permuted by ``Pi_e[S_G][i]``,
        consistently across the columns of ``W_A``, the entries of ``b_A`` and
        the rows of ``W_B``.
    """
    g_encoders = []
    for encoder, group_action in zip(encoders, group_actions):
        S_h, S_G = group_action["S_h"], group_action["S_G"]
        gamma_W, gamma_b = group_action["gamma_W"], group_action["gamma_b"]
        # Head-reordered change-of-basis matrices, shared by each pair of weights.
        M_k_h = group_action["M_k"][S_h]
        M_v_h = group_action["M_v"][S_h]
        # Row i: hidden-unit permutation for the expert now at position i.
        # Computed once per layer instead of once per expert per key.
        pi = group_action["Pi_e"][S_G]

        g_encoder = {}
        for key, value in encoder.items():
            if key == "W_q":
                g_encoder[key] = value[S_h] @ M_k_h.transpose(-1, -2)
            elif key == "W_k":
                g_encoder[key] = value[S_h] @ torch.inverse(M_k_h)
            elif key == "W_v":
                g_encoder[key] = value[S_h] @ M_v_h
            elif key == "W_o":
                g_encoder[key] = torch.inverse(M_v_h) @ value[S_h]
            elif key == "W_G":
                g_encoder[key] = value[S_G] + gamma_W.unsqueeze(0)
            elif key == "b_G":
                g_encoder[key] = value[S_G] + gamma_b
            elif key == "W_A":
                experts = value[S_G]
                g_encoder[key] = torch.stack(
                    [experts[i][:, pi[i]] for i in range(experts.shape[0])], dim=0)
            elif key == "b_A":
                experts = value[S_G]
                g_encoder[key] = torch.stack(
                    [experts[i][pi[i]] for i in range(experts.shape[0])], dim=0)
            elif key == "W_B":
                experts = value[S_G]
                g_encoder[key] = torch.stack(
                    [experts[i][pi[i], :] for i in range(experts.shape[0])], dim=0)
            elif key == "b_B":
                g_encoder[key] = value[S_G]
            else:
                g_encoder[key] = value
        g_encoders.append(g_encoder)
    return g_encoders

def check_group_action():
    """Numerically check that a sampled group action preserves the model function.

    Builds a small float64 ``VisionTransformer`` on CUDA and maps its encoder
    weights into weight space. It then applies an independently sampled group
    action to every encoder layer, loads the transformed weights into a second
    model and compares both models' outputs on a random batch.

    Returns:
        int: number of samples whose outputs agree within the tolerance (the
        same count is printed).
    """
    embed_dim = 32
    n_layers = 2
    n_heads = 2
    n_experts = 4
    forward_mul = 2
    image_size = 28
    n_channels = 1
    patch_size = 4
    bsz = 1024
    D_k = embed_dim // n_heads
    D_v = embed_dim // n_heads
    D_A = embed_dim * forward_mul
    D = embed_dim

    model_1 = VisionTransformer(n_channels=n_channels, embed_dim=embed_dim, n_layers=n_layers, n_attention_heads=n_heads,
                                forward_mul=forward_mul, image_size=image_size,
                                patch_size=patch_size, n_classes=10, dropout=0.0).to("cuda").to(torch.float64)  # Use float64 precision
    # Evaluation mode: the comparison should not depend on train-only behaviour.
    model_1.eval()
    checkpoint_1 = model_1.state_dict()

    converter = Checkpoint_WeightSpace_Converter()
    encoders_ws = converter.convert_state_dict_to_weight_space(checkpoint_1, n_layers, n_heads)

    # state_dict() tensors alias model_1's parameters; write the transformed
    # weights into an independent deep copy instead.
    checkpoint_2 = copy.deepcopy(model_1.state_dict())
    group_actions = [sample_group_action(n_heads, n_experts, D_k, D_v, D_A, D) for _ in range(n_layers)]
    g_encoders_ws = apply_group_action(encoders_ws, group_actions)

    checkpoint_2 = converter.convert_weight_space_to_state_dict(g_encoders_ws, checkpoint_2)
    model_2 = VisionTransformer(n_channels=n_channels, embed_dim=embed_dim, n_layers=n_layers, n_attention_heads=n_heads,
                                forward_mul=forward_mul, image_size=image_size,
                                patch_size=patch_size, n_classes=10, dropout=0.0).to("cuda").to(torch.float64)  # Use float64 precision
    model_2.load_state_dict(checkpoint_2)
    model_2.eval()

    batch = torch.rand([bsz, n_channels, image_size, image_size]).to("cuda").to(torch.float64)  # Use float64 precision

    # Inference only: skip building an autograd graph for two 1024-sample passes.
    with torch.no_grad():
        result_1 = model_1(batch)
        result_2 = model_2(batch)

    # Per-sample torch.allclose, vectorised: a sample matches when every one
    # of its outputs is close (same rtol/atol and NaN handling).
    close = torch.isclose(result_1, result_2, rtol=1e-4, atol=1e-6)
    num_correct = int(close.reshape(close.shape[0], -1).all(dim=1).sum().item())
    print(f"Group action correct for: {num_correct}/{result_1.shape[0]} samples")
    return num_correct

if __name__ == "__main__":
    # Seed all RNGs before anything is created, so that model initialisation,
    # the sampled group actions and the input batch are reproducible.
    seed = 6
    set_seed(seed)
    check_group_action()
