import math
import random
import numpy as np
import torch
import pandas as pd
from functools import partial

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from e3nn.o3 import matrix_z, rand_matrix

from model import TFNModel, HEGNNModel


def set_seed(seed: int = 42):
    """Seed every RNG this script relies on and pin cuDNN to deterministic mode.

    Args:
        seed: Value passed to Python's ``random``, NumPy's legacy global RNG,
            and PyTorch's CPU and CUDA generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Disable cuDNN autotuning and request deterministic kernels so repeated
    # runs pick the same algorithms.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_k_fold(fold, rot_mod):
    """Build a star graph with ``fold`` spokes plus a rotated copy of it.

    Node 0 sits at the origin. Nodes ``1..fold`` are obtained by rotating
    ``(1, 0, 0)`` with ``matrix_z`` by ``2*pi*k/fold`` for ``k = 0..fold-1``.
    Edges run from node 0 to each of the other nodes. Node features and edge
    attributes are all-zero single-channel tensors.

    Args:
        fold: Number of spoke nodes; the graph has ``fold + 1`` nodes.
        rot_mod: ``'2d'`` rotates the copy with ``matrix_z`` by
            ``2*pi / (fold + r)`` where ``r = random.randint(1, fold)``.
            Any other value rotates the copy with ``rand_matrix()``.

    Returns:
        A two-element list ``[original, rotated]`` of ``Data`` objects. Both
        share the same ``x``, ``edge_index`` and ``edge_attr`` tensors and
        differ only in ``pos``. ``fold`` is stored as a string.
    """
    node_count = fold + 1
    x = torch.zeros(node_count, 1, dtype=torch.float32)
    edge_attr = torch.zeros(fold, 1, dtype=torch.float32)

    # Hub -> spoke connectivity: row 0 holds the hub index, row 1 the spokes.
    hub_ids = [0] * fold
    spoke_ids = list(range(1, fold + 1))
    edge_index = torch.LongTensor([hub_ids, spoke_ids])

    # Spoke positions: the unrotated base vector first, then its rotations.
    origin = torch.tensor([0.0, 0.0, 0.0])
    base = torch.tensor([1.0, 0.0, 0.0])
    spokes = [base]
    for k in range(1, fold):
        angle = 2 * math.pi * k / fold
        rot_k = matrix_z(torch.tensor([angle])).squeeze(0)
        spokes.append(base @ rot_k.T)
    pos = torch.stack([origin, *spokes])

    if rot_mod == '2d':
        # The denominator lies in [fold + 1, 2 * fold], so the angle is
        # strictly between 0 and 2*pi/fold and is never a multiple of the
        # spoke spacing.
        denom = fold + random.randint(1, fold)
        Q = matrix_z(torch.tensor([2 * math.pi / denom])).squeeze(0)
    else:
        Q = rand_matrix()

    shared = dict(
        fold=str(fold),
        rot_mod=rot_mod,
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
    )
    return [Data(pos=coords, **shared) for coords in (pos, pos @ Q.T)]


def get_dataset():
    """Return a non-shuffling loader over all k-fold graph pairs.

    For each fold in ``[2, 3, 4, 6]`` and each rotation mode (``"2d"`` then
    ``"3d"``), the two graphs returned by ``get_k_fold`` are appended in
    order, so even positions hold the first graph of each pair and odd
    positions the second.

    Returns:
        A ``DataLoader`` with ``batch_size=1024`` and ``shuffle=False``.
    """
    graphs = [
        graph
        for fold in [2, 3, 4, 6]
        for rot_mod in ["2d", "3d"]
        for graph in get_k_fold(fold, rot_mod)
    ]
    return DataLoader(graphs, batch_size=1024, shuffle=False)


def get_model(name, channels, num_layers):
    """Instantiate the model class registered under ``name``.

    Args:
        name: ``"TFN"`` or ``"HEGNN"``.
        channels: Forwarded as ``irreps_channels``.
        num_layers: Forwarded as ``num_layer``.

    Returns:
        A new model built with ``max_ell=11``, ``hidden_dim=64`` and
        single-channel node and edge inputs.

    Raises:
        KeyError: If ``name`` is not one of the registered model names.
    """
    settings = dict(
        max_ell=11,
        num_layer=num_layers,
        hidden_dim=64,
        irreps_channels=channels,
        node_input_dim=1,
        edge_input_dim=1,
    )
    registry = {"TFN": TFNModel, "HEGNN": HEGNNModel}
    return registry[name](**settings)


def degree_norm(vec, channels, max_ell=11):
    """Return the channel-averaged L2 norm of each degree block of ``vec``.

    ``vec`` is read as consecutive blocks, one per degree
    ``ell = 0..max_ell``. Each block holds ``channels * (2*ell + 1)`` entries
    and is viewed as ``(channels, 2*ell + 1)``. Entries beyond the last block
    are ignored.

    NOTE(review): this assumes the embedding is laid out degree by degree
    with each channel's ``2*ell + 1`` components contiguous; confirm against
    the model's output layout.

    Args:
        vec: 1-D tensor with at least ``channels * (max_ell + 1) ** 2``
            entries.
        channels: Number of channels per degree.
        max_ell: Highest degree to process.

    Returns:
        A list of ``max_ell + 1`` Python floats, one per degree.
    """
    per_degree = []
    start = 0
    for width in (2 * ell + 1 for ell in range(max_ell + 1)):
        stop = start + channels * width
        rows = vec[start:stop].view(channels, width)
        # Norm over the 2*ell + 1 components, then mean over channels.
        per_degree.append(torch.linalg.norm(rows, dim=1).mean().item())
        start = stop
    return per_degree


def get_emb_csv(path="./kfold_emb_degree_norm.csv"):
    """Write per-degree norms of embedding differences for each graph pair.

    For every combination of model type (``"HEGNN"``, ``"TFN"``), layer count
    (1-4) and channel count (1, 4, 16), a fresh model is built and run in
    eval mode on the dataset from ``get_dataset``. Even-indexed and
    odd-indexed embeddings are subtracted, and ``degree_norm`` of each
    difference is stored as columns ``deg0``..``deg11``.

    NOTE(review): this assumes ``model(batch)`` returns one embedding row per
    graph in batch order, so even/odd rows line up with the pairs produced by
    ``get_k_fold``; confirm against the model implementations.

    Args:
        path: Destination CSV file.

    Returns:
        The ``pandas.DataFrame`` that was written to ``path``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loader = get_dataset()

    # Flattened in the same order as nested loops (model type outermost,
    # channels innermost) so models are constructed in the same sequence.
    configs = [
        (model_type, num_layers, channels)
        for model_type in ["HEGNN", "TFN"]
        for num_layers in [1, 2, 3, 4]
        for channels in [1, 4, 16]
    ]

    rows = []
    for model_type, num_layers, channels in configs:
        model = get_model(model_type, channels, num_layers).to(device).eval()

        with torch.no_grad():
            for batch in loader:
                batch = batch.to(device)

                emb = model(batch)
                if emb.ndim == 1:
                    # Promote a single flat embedding to a one-row matrix.
                    emb = emb.unsqueeze(0)

                first = emb[0::2].cpu()
                second = emb[1::2].cpu()
                diff = first - second

                # Metadata of the first graph of each pair.
                folds = batch.fold[0::2]
                rotmods = batch.rot_mod[0::2]

                for i, delta in enumerate(diff):
                    deg = degree_norm(delta, channels)

                    row = {
                        "model_type": model_type,
                        "num_layer": num_layers,
                        "irreps_channels": channels,
                        "fold": folds[i],
                        "rot_mod": rotmods[i],
                    }
                    # 12 columns: degrees 0..11, matching degree_norm's
                    # default max_ell of 11.
                    row.update((f"deg{ell}", deg[ell]) for ell in range(12))
                    rows.append(row)

    df = pd.DataFrame(rows)
    df.to_csv(path, index=False)
    print(f"saved {len(df)} rows to {path}")
    return df


if __name__ == "__main__":
    # Seed first: graph construction draws from `random` and (via
    # rand_matrix) presumably from torch's RNG, so seeding must precede
    # get_emb_csv for runs to be repeatable.
    set_seed(0)
    # Runs every model configuration and writes the CSV to the default path.
    get_emb_csv()
