import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from typing import Tuple
from argparse import Namespace
from torch_geometric.transforms import BaseTransform, Compose, AddSelfLoops
from geometric_transforms.num_neighbours import FullyConnected
import wandb
import numpy as np
import random
import os


class OneHotTransform(BaseTransform):
    """One-hot encode the integer node and edge labels of a PyG ``Data`` object.

    ``data.x`` is replaced by a float one-hot matrix with ``num_atoms`` columns and
    ``data.edge_attr`` by a float one-hot matrix with ``num_bonds`` columns.
    Edge labels are shifted down by one before encoding, so they are expected to
    start at 1 (NOTE(review): presumably bond types 1..num_bonds -- confirm
    against the dataset).
    """

    def __init__(self, num_atoms, num_bonds):
        self.num_atoms = num_atoms  # number of one-hot classes for node labels
        self.num_bonds = num_bonds  # number of one-hot classes for edge labels

    def __call__(self, data: Data):
        # Drop singleton dims but always keep at least one dimension: a graph with a
        # single node (x of shape [1, 1]) or a single edge would otherwise squeeze to
        # a 0-d tensor, and one_hot would return a vector instead of a
        # (1, num_classes) matrix.
        atoms = torch.atleast_1d(torch.squeeze(data.x))
        bonds = torch.atleast_1d(torch.squeeze(data.edge_attr))
        data['x'] = F.one_hot(atoms, self.num_atoms).float()
        data['edge_attr'] = F.one_hot(bonds - 1, self.num_bonds).float()

        return data


def generate_loaders(args: dict) -> Tuple[DataLoader, DataLoader, DataLoader, int, int, int]:
    """Build the train/val/test ``DataLoader``s for ``args['dataset']``.

    ``args`` is indexed with ``args[...]``, so it must be a dict-like config.
    Keys read here are ``dataset``, ``struc_info_type``, ``batch_size`` and, for
    ZINC, ``model_name``. ``generate_geometric_transform`` may read more keys.

    Supported datasets:
        * ``'zinc'``: ZINC with ``subset=True``; node and edge labels are one-hot
          encoded.
        * any name starting with ``'peptide'``: LRGB ``Peptides-struct`` when the
          name's suffix from index 9 onward is ``'struct'``, otherwise
          ``Peptides-func``.

    Returns:
        ``(train_loader, val_loader, test_loader, num_atoms, num_bonds, geom_size)``,
        where ``geom_size`` is ``data.p.shape[1]`` of the first training graph, or 0
        when graphs carry no ``p`` attribute.

    Raises:
        ValueError: if the dataset name is not recognised.
    """
    dataset = args['dataset']

    if dataset == 'zinc':
        from torch_geometric.datasets import ZINC
        num_atoms = 26
        num_bonds = 3

        transform = OneHotTransform(num_atoms, num_bonds)
        if args["model_name"] != 'mpnn':
            # Every model except the plain MPNN also gets a structural encoding.
            transform = Compose([transform, generate_geometric_transform(args)])

        # The transform is applied as a pre_transform, so its output is stored in the
        # processed files. Each structural-encoding type therefore gets its own root.
        root = f'datasets/{dataset}_{args["struc_info_type"]}'
        train_dataset, val_dataset, test_dataset = [
            ZINC(root, subset=True, split=split, pre_transform=transform)
            for split in ('train', 'val', 'test')
        ]

    elif dataset.startswith('peptide'):
        from torch_geometric.datasets import LRGBDataset
        num_atoms = 9
        num_bonds = 3

        transform = Compose([AddSelfLoops(), generate_geometric_transform(args)])
        # NOTE(review): assumes names like 'peptides-struct' / 'peptides-func' (the
        # first 9 characters are skipped); any other suffix selects Peptides-func.
        name = 'Peptides-struct' if dataset[9:] == 'struct' else 'Peptides-func'

        root = f'datasets/{dataset}_{args["struc_info_type"]}'
        train_dataset, val_dataset, test_dataset = [
            LRGBDataset(root, name=name, split=split, pre_transform=transform)
            for split in ('train', 'val', 'test')
        ]
    else:
        raise ValueError(f'Do not recognize dataset {args["dataset"]}.')

    # Shuffle the training set once up front; the train loader below also
    # reshuffles every epoch.
    train_dataset = train_dataset.shuffle()

    batch_size = args["batch_size"]
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Width of the positional/structural feature matrix, if the graphs have one.
    first_graph = train_loader.dataset[0]
    geom_size = first_graph.p.shape[1] if hasattr(first_graph, 'p') else 0

    return train_loader, val_loader, test_loader, num_atoms, num_bonds, geom_size


def generate_geometric_transform(args):
    """Return the structural-encoding transform selected by ``args["struc_info_type"]``.

    Supported values:
        * ``'num_neighbours'``: ``NumNeighboursTransform()``.
        * ``'laplacian'``: ``LaplacianTransform`` with ``num_eigen_vec=args['struc_dim']``.
        * ``'random_walk'``: ``RandomWalkTransform`` with ``max_walk_len=args['struc_dim']``.

    Raises:
        ValueError: for any other ``struc_info_type``.
    """
    struc_info_type = args["struc_info_type"]

    # Each transform module is imported lazily, so only the selected one is loaded.
    if struc_info_type == 'num_neighbours':
        from geometric_transforms.num_neighbours import NumNeighboursTransform
        return NumNeighboursTransform()
    if struc_info_type == 'laplacian':
        from geometric_transforms.laplacian import LaplacianTransform
        return LaplacianTransform(num_eigen_vec=args['struc_dim'])
    if struc_info_type == 'random_walk':
        from geometric_transforms.random_walk import RandomWalkTransform
        return RandomWalkTransform(max_walk_len=args['struc_dim'])

    # The original built this exception without raising it, which surfaced later
    # as an UnboundLocalError on `transform`.
    raise ValueError(f'Do not recognize geometric information type {args["struc_info_type"]}.')


def generate_model(args: Namespace, feat_in: int, pos_in: int, edge_in: int) -> nn.Module:
    """Instantiate the model named by ``args["model_name"]``.

    Only ``'pe_mpnn'`` is supported. It builds ``models.pe_mpnn.PE_MPNN`` from
    ``feat_in``/``pos_in`` plus the hyper-parameters ``feat_hidden``,
    ``pos_hidden``, ``num_out``, ``num_layers``, ``state_type``, ``layer_type``,
    ``ent_deg`` and ``red`` read from ``args``. ``edge_in`` is accepted but not
    used by this function.

    Raises:
        ValueError: if the model name is not recognised.
    """
    if args["model_name"] != 'pe_mpnn':
        raise ValueError(f'Do not recognize model name {args["model_name"]}.')

    from models.pe_mpnn import PE_MPNN

    hyper_params = {
        key: args[key]
        for key in (
            "feat_hidden",
            "pos_hidden",
            "num_out",
            "num_layers",
            "state_type",
            "layer_type",
            "ent_deg",
            "red",
        )
    }
    return PE_MPNN(feat_in=feat_in, pos_in=pos_in, **hyper_params)


def add_pos_information(data, feature):
    """Append ``feature`` to ``data.p`` column-wise, or set ``p`` if absent.

    The result is stored back on ``data`` under the key ``'p'``, and ``data`` is
    returned.
    """
    combined = torch.cat((data.p, feature), dim=1) if hasattr(data, 'p') else feature
    data['p'] = combined
    return data


def create_state(state_type, x, p):
    """Combine node features ``x`` with positional features ``p`` into one state.

    * ``'concat'``: concatenate along dim 1.
    * ``'tensor'`` / ``'ent_bigro'``: per-row outer product
      (``einsum("nh,np->nhp")``), flattened from dim 1 onward.

    Raises:
        ValueError: for any other ``state_type``.
    """
    if state_type == 'concat':
        return torch.cat((x, p), dim=1)

    if state_type not in ('tensor', 'ent_bigro'):
        raise ValueError(f'Do not recognize type {state_type}.')

    outer = torch.einsum("nh,np->nhp", x, p)
    return outer.flatten(start_dim=1)

def init_wandb(args):
    """Start a Weights & Biases run in project ``ICML-TAG-StrongStructuralEncodings``.

    ``args`` is then recorded in the run's config.
    """
    wandb.init(project="ICML-TAG-StrongStructuralEncodings")
    wandb.config.update(args)

def get_criterion(args):
    """Return the loss module for ``args['task']``.

    ``'reg'`` gives ``nn.L1Loss``; any other task gives ``nn.BCEWithLogitsLoss``.
    """
    is_regression = args['task'] == 'reg'
    return nn.L1Loss() if is_regression else nn.BCEWithLogitsLoss()

def set_seed(seed: int = 42) -> None:
    """Seed NumPy, ``random`` and torch (CPU + CUDA), and make cuDNN deterministic.

    Also exports ``PYTHONHASHSEED``. Hash randomisation is fixed when the
    interpreter starts, so this only affects child processes launched afterwards,
    not the current one.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # Both cuDNN flags are needed on the cuDNN backend for repeatable kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


def flip_signs(data):
    """Randomly flip the sign of each column of ``data.p``.

    One uniform draw is made per column (on CPU, then moved to ``data.p``'s
    device). Columns whose draw is >= 0.5 keep their sign; the others are negated.
    The same pattern applies to every row. ``data.p`` is reassigned and ``data``
    is returned.
    """
    draws = torch.rand(data.p.size(1))
    # Map the boolean mask {True, False} to {+1, -1} in the draws' dtype.
    signs = draws.ge(0.5).to(draws.dtype) * 2.0 - 1.0
    data.p = data.p * signs.to(data.p.device).unsqueeze(0)

    return data

def log_results(results):
    """Flatten ``{split: {metric: value}}`` into ``{"<split> <metric>": value}``.

    The flattened dict is sent to ``wandb.log`` in a single call.
    """
    flat = {
        f'{split} {metric}': value
        for split, split_res in results.items()
        for metric, value in split_res.items()
    }
    wandb.log(flat)
