import argparse
import os
import sys
import numpy as np
import random

sys.path.append(os.getcwd())

import hydra
from omegaconf import DictConfig, OmegaConf

import torch
import torch.nn.functional as F
import wandb
from torch_geometric.logging import log

import torch_geometric.transforms as T
import torch_geometric.utils as geo_utils
from dataset_handler import *
from get_models import get_model
from folder_dataloader import FolderDataLoader
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

from train_utils import *

# Module-wide compute device: CUDA when torch reports it available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import warnings
# Suppress UserWarnings whose message begins with 'TypedStorage is deprecated'.
warnings.filterwarnings(
    'ignore', 
    category=UserWarning, 
    message='TypedStorage is deprecated')



def createEdgewiseData(data, edgewise_edge_index):
    """Build a ``Data`` object whose "nodes" are the edges of ``data``.

    Each edge of ``data`` receives, as its feature vector, the features of
    the node stored in row 0 of ``data.edge_index`` for that edge.

    Args:
        data: Graph providing ``x`` and ``edge_index``.
        edgewise_edge_index: Connectivity to attach to the new object.

    Returns:
        A ``Data`` instance with ``x`` and ``edge_index`` set.
    """
    first_endpoints = data.edge_index[0]
    per_edge_features = data.x[first_endpoints]
    return Data(x=per_edge_features, edge_index=edgewise_edge_index)


def get_node_feature(batch, data):
    """Collect ids, features and labels of the nodes touched by a batch.

    ``batch.n_id`` is used to select columns of ``data.edge_index``; every
    endpoint appearing in those columns is gathered and deduplicated with
    ``unique()`` (which returns the ids sorted).

    Args:
        batch: Object exposing ``n_id``.
        data: Graph exposing ``edge_index``, ``x`` and ``y``.

    Returns:
        Tuple ``(node_indices, node_features, y)`` where the last two are
        ``data.x`` and ``data.y`` indexed by ``node_indices``.
    """
    selected_edges = data.edge_index[:, batch.n_id]
    node_indices = selected_edges.unique()
    # Features and labels are looked up with the same ids, so row k of each
    # corresponds to node_indices[k].
    return node_indices, data.x[node_indices], data.y[node_indices]


def get_edge_index(batch, data):
    """Relabel the batch's edges with compact, batch-local node ids.

    ``batch.n_id`` selects columns of ``data.edge_index``. Each endpoint in
    those columns is replaced by its position in the sorted list of distinct
    endpoints, i.e. the same ordering ``get_node_feature`` obtains from
    ``unique()``, so the result indexes directly into its outputs.

    Args:
        batch: Object exposing ``n_id``.
        data: Graph exposing ``edge_index``.

    Returns:
        Tensor with the same shape, dtype and device as the selected slice
        of ``data.edge_index``, holding the relabeled endpoints.
    """
    subset_edge_index = data.edge_index[:, batch.n_id]
    # `return_inverse` yields, for every element, its index in the sorted
    # unique values. This is exactly the relabeling, computed on the input's
    # own device without a lookup table sized by the full graph or any
    # dependence on the module-level `device`.
    _, relabeled_edges = subset_edge_index.unique(return_inverse=True)
    return relabeled_edges.to(subset_edge_index.dtype)


@torch.no_grad()
def test(model, data, loader, loss_fn, edgewise_edge_index=None):
    """Evaluate ``model`` on the validation and test splits.

    For every batch the model output is masked with ``data.val_mask`` /
    ``data.test_mask`` (indexed by the batch's global node ids) and the loss
    and argmax-accuracy are accumulated.

    Args:
        model: Network to evaluate; put into eval mode.
        data: Full graph providing ``test_mask`` and ``val_mask`` (and, for
            edge-based models, ``edge_index``/``x``/``y``).
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``n_id``
            and, for node-based models, ``y`` and ``batch``.
        loss_fn: Callable ``loss_fn(prediction, target)``.
        edgewise_edge_index: Unused; kept for call-site compatibility.

    Returns:
        Tuple ``(test_loss, val_loss, test_acc, val_acc)``. Losses are the
        mean of the per-batch losses; accuracies are computed over all
        masked predictions.
    """
    model.eval()
    uses_edge_graph = model.__class__.__name__ in (
        'BeliefPropLayers', 'GraphEdgeNetwork', 'EdgeGCN')
    test_mask = data.test_mask
    val_mask = data.val_mask

    test_accs = []
    val_accs = []
    test_losses = []
    val_losses = []
    for batch in loader:
        batch = batch.to(device)
        if uses_edge_graph:
            # The loader iterates the edge-wise graph: batch.x holds edge
            # features and batch.n_id holds edge ids of the original graph.
            node_indices, node_feature, y = get_node_feature(batch, data)
            relabeled_edges = get_edge_index(batch, data)
            out = model(
                x=batch.x,
                edge_index=relabeled_edges,
                node_feature=node_feature,
                edgewise_edge_index=batch.edge_index
            )
            t_mask = test_mask[node_indices]
            v_mask = val_mask[node_indices]
        else:
            # `.float()` casts in place on the current device; the previous
            # `.type(torch.FloatTensor)` went through a CPU tensor first.
            out = model(batch.x.float(),
                        batch.edge_index,
                        batch=batch.batch)
            y = batch.y
            t_mask = test_mask[batch.n_id]
            v_mask = val_mask[batch.n_id]

        test_losses.append(loss_fn(out[t_mask], y[t_mask]))
        val_losses.append(loss_fn(out[v_mask], y[v_mask]))
        pred = out.argmax(-1)
        test_accs.append(pred[t_mask] == y[t_mask])
        val_accs.append(pred[v_mask] == y[v_mask])

    test_loss = torch.mean(torch.tensor(test_losses))
    # Average the *validation* losses; previously the test losses were
    # averaged here and only the last batch's validation loss was returned.
    val_loss = torch.mean(torch.tensor(val_losses))

    val_accs = torch.hstack(val_accs)
    val_accs = val_accs.sum()/len(val_accs)
    test_accs = torch.hstack(test_accs)
    test_accs = test_accs.sum()/len(test_accs)

    return test_loss, val_loss, test_accs, val_accs



def train(model, data, loader, optimizer, loss_fn, 
          edgewise_edge_index=None, 
          scheduler=None, use_wandb=False):
    """Run one training pass over ``loader``.

    Args:
        model: Network to optimise; put into train mode.
        data: Full graph providing ``train_mask`` (and, for edge-based
            models, ``edge_index``/``x``/``y``).
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``n_id``
            and, for node-based models, ``y`` and ``batch``.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(prediction, target)``.
        edgewise_edge_index: Unused; kept for call-site compatibility.
        scheduler: Optional LR scheduler, stepped once per call (i.e. once
            per epoch when this function is called once per epoch).
        use_wandb: When true, log output statistics and the loss to wandb
            every 10th batch.

    Returns:
        Tuple ``(full_loss, full_accs)``: mean of the per-batch losses and
        the accuracy over all masked training predictions.
    """
    model.train()
    uses_edge_graph = model.__class__.__name__ in (
        'BeliefPropLayers', 'GraphEdgeNetwork', 'EdgeGCN')
    mask = data.train_mask

    accs = []
    losses = []
    for i, batch in enumerate(loader):
        batch = batch.to(device)
        # Clear gradients at the start of every batch. Doing it only once
        # before the loop made each step use the sum of all previous
        # batches' gradients. Zeroing here (rather than after the step)
        # keeps the last batch's gradients available after the call.
        optimizer.zero_grad()
        if uses_edge_graph:
            # The loader iterates the edge-wise graph: batch.x holds edge
            # features and batch.n_id holds edge ids of the original graph.
            relabeled_edges = get_edge_index(batch, data)
            node_indices, node_feature, y = get_node_feature(batch, data)
            out = model(
                x=batch.x,
                edge_index=relabeled_edges,
                node_feature=node_feature,
                edgewise_edge_index=batch.edge_index
            )
            train_mask = mask[node_indices]
        else:
            # `.float()` casts in place on the current device; the previous
            # `.type(torch.FloatTensor)` went through a CPU tensor first.
            out = model(batch.x.float(),
                        batch.edge_index,
                        batch=batch.batch)
            y = batch.y
            train_mask = mask[batch.n_id]

        loss = loss_fn(
            out[train_mask], y[train_mask]
        )
        loss.backward()
        optimizer.step()
        pred = out.argmax(-1)
        accs.append(pred[train_mask] == y[train_mask])
        losses.append(loss.item())

        if use_wandb and i % 10 == 0:
            detached = out.detach()
            wandb.log({
                "data_stats/out_mean": detached.mean().item(),
                "data_stats/out_var": detached.var().item(),
                "data_stats/out_max": detached.max().item(),
                "data_stats/out_min": detached.min().item(),
                "mid_train_loss/loss": loss.item(),
            })

    # The cosine schedule in this file is configured with T_max=args.epochs,
    # so it must advance once per epoch, not once per batch.
    if scheduler is not None:
        scheduler.step()

    full_loss = torch.mean(torch.tensor(losses))
    full_accs = torch.hstack(accs)
    full_accs = full_accs.sum()/len(full_accs)
    return full_loss, full_accs
        


def initialize_optimizer(model, args, optimizer_type='adam'):
    """Create the optimizer and optional learning-rate scheduler.

    Args:
        model: Module whose ``parameters()`` are optimised.
        args: Config exposing ``model.lr``, ``model.weight_decay``,
            ``lr_schedule`` and (for the cosine schedule) ``epochs``.
        optimizer_type: ``'adam'`` or ``'sgd'``.

    Returns:
        Tuple ``(optimizer, scheduler)``; ``scheduler`` is ``None`` when
        ``args.lr_schedule == 'constant'``.

    Raises:
        ValueError: If ``optimizer_type`` or ``args.lr_schedule`` is not one
            of the supported options.
    """
    if optimizer_type == 'adam':
        optimizer_cls = torch.optim.Adam
    elif optimizer_type == 'sgd':
        # Previously this branch silently built Adam as well.
        optimizer_cls = torch.optim.SGD
    else:
        raise ValueError(f"Invalid optimizer option {optimizer_type}")
    optimizer = optimizer_cls(
        model.parameters(),
        lr=args.model.lr,
        weight_decay=args.model.weight_decay)

    if args.lr_schedule == 'cosine_lr':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs)
    elif args.lr_schedule == 'constant':
        scheduler = None
    else:
        # The exception used to be constructed but never raised, which
        # surfaced later as an UnboundLocalError on `scheduler`.
        raise ValueError("Invalid Lr schedule")
    return optimizer, scheduler

def log_gradients(model):
    """Log the L2 norm of every parameter gradient, and their total, to wandb.

    Per-parameter norms are logged under ``gradients/<name>`` and the norm of
    all gradients combined under ``Loss/total_gradient``. Parameters that
    currently have no gradient (``grad is None``, e.g. frozen or unused in
    the last backward pass) are skipped instead of raising.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
    """
    squared_total = 0
    grad_dict = dict()
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        param_norm = p.grad.detach().norm(2).item()
        grad_dict[f"gradients/{name}"] = param_norm
        squared_total += param_norm ** 2
    grad_dict["Loss/total_gradient"] = squared_total ** (1/2)
    wandb.log(grad_dict)

def log_statistics(model):
    """Log the max, min and mean of every model parameter to wandb.

    Values are logged under ``model_max/<name>``, ``model_min/<name>`` and
    ``model_mean/<name>`` in a single ``wandb.log`` call.

    Args:
        model: Module whose ``named_parameters()`` are summarised.
    """
    reducers = (
        ("model_max", torch.max),
        ("model_min", torch.min),
        ("model_mean", torch.mean),
    )
    summary = {}
    for name, p in model.named_parameters():
        if p is None:
            continue
        for prefix, reduce_fn in reducers:
            summary[f"{prefix}/{name}"] = reduce_fn(p).detach().item()
    wandb.log(summary)


@hydra.main(config_path="./configs", config_name="config")
def main(args: DictConfig):
    """Train and evaluate a model as described by the Hydra config.

    Seeds the RNGs, optionally initialises wandb, loads the dataset, builds
    the model/optimizer/loaders, then alternates ``train`` and ``test`` for
    ``args.epochs`` epochs with optional console/wandb logging and
    checkpointing.

    Args:
        args: Hydra config. Fields read here: ``seed``, ``use_wandb``,
            ``wandb.*``, ``log_model``, ``workdir.*``, ``dataset.loss_type``,
            ``dataset.batch_size``, ``model``, ``optimizer``, ``epochs``,
            ``print``.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``args.dataset.loss_type`` is not supported.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    if args.use_wandb:
        wandb_config = OmegaConf.to_container(
            args, resolve=True, throw_on_missing=True)
        init_kwargs = dict(
            project=args.wandb.project,
            group=args.wandb.group,
            name=args.wandb.name,
            config=wandb_config,
        )
        # Only pass `tags` when the config provides them.
        if args.wandb.tags is not None:
            init_kwargs['tags'] = list(args.wandb.tags)
        wandb.init(**init_kwargs)

    checkpoint_dir = None
    if args.log_model:
        checkpoint_dir = os.path.join(args.workdir.root, args.workdir.name)
        os.makedirs(checkpoint_dir, exist_ok=True)

    _, data = get_dataset(args)
    data = data.to(device)
    print(f"num edges: {data.edge_index.shape}")

    loss_type = args.dataset.loss_type
    # Branching on the configured loss *name*; comparing the loss module
    # itself to a string (as before) is always False.
    is_l2_loss = loss_type == 'l2_loss'
    if loss_type == 'cross_entropy_loss':
        loss_fn = torch.nn.CrossEntropyLoss()
        num_features = data.x.shape[-1]
        num_classes = len(data.y.unique())

        homophily_level = geo_utils.homophily(data.edge_index, data.y)
        print(f'Homophily level of the dataset is {homophily_level}')
        if args.use_wandb:
            wandb.log({
                "Homophily Level": homophily_level,
            })
    elif is_l2_loss:
        loss_fn = torch.nn.MSELoss()
        num_features = 1
        num_classes = 1
    else:
        # Previously the exception was created but not raised, so execution
        # continued with `loss_fn`/`num_features` undefined.
        raise ValueError("Invalid Loss Type")

    model = get_model(
        args.model,
        num_features=num_features,
        num_classes=num_classes)
    model = model.to(device)
    optimizer, scheduler = initialize_optimizer(
        model, args, optimizer_type=args.optimizer)

    # we are not going to use the edge_attr for a while.
    # NOTE(review): train()/test() also treat 'BeliefPropLayers' as an
    # edge-based model, but it is not listed here — confirm which is intended.
    if model.__class__.__name__ in ['GraphEdgeNetwork', 'EdgeGCN']:
        edgewise_edge_index = get_edgewise_graph(data.edge_index)
        loader_source = createEdgewiseData(data, edgewise_edge_index)
        train_batch_size = args.dataset.batch_size
    else:
        edgewise_edge_index = None
        loader_source = data
        # NOTE(review): hard-coded, unlike every other loader which uses
        # args.dataset.batch_size — confirm whether this is deliberate.
        train_batch_size = 128
    loader = NeighborLoader(
        loader_source,
        num_neighbors=[-1],
        batch_size=train_batch_size,
        shuffle=True,
    )
    test_loader = NeighborLoader(
        loader_source,
        num_neighbors=[-1],
        batch_size=args.dataset.batch_size,
        shuffle=False,
    )

    total_params = sum(p.numel() for p in model.parameters())
    total_grad_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total parameters = {total_params} require_grad {total_grad_params}')

    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = train(
            model, data, loader,
            optimizer, loss_fn,
            edgewise_edge_index=edgewise_edge_index,
            scheduler=scheduler,
            use_wandb=args.use_wandb)
        test_loss, val_loss, test_acc, val_acc = test(
            model, data, test_loader, loss_fn,
            edgewise_edge_index=edgewise_edge_index)

        if args.print:
            if is_l2_loss:
                log(Epoch=epoch,
                    Loss=train_loss,
                    Train=train_acc,
                    Val=val_acc,
                    Test=test_acc)
            else:
                log(Epoch=epoch,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    test_loss=test_loss,
                    train_acc=train_acc,
                    test_acc=test_acc,)

        if args.use_wandb:
            epoch_metrics = {
                'Loss/train': train_loss,
                'Loss/val': val_loss,
                'Loss/test': test_loss,
            }
            # Accuracy is only reported for the non-l2 (classification) case.
            if not is_l2_loss:
                epoch_metrics['Loss/test_acc'] = test_acc
                epoch_metrics['Loss/train_acc'] = train_acc
            epoch_metrics['epoch'] = epoch
            wandb.log(epoch_metrics)
            log_gradients(model)
            log_statistics(model)

        # Periodic checkpoint every 100 epochs.
        if args.log_model and epoch % 100 == 0:
            torch.save(
                model,
                os.path.join(checkpoint_dir, f'checkpoint_{epoch:04}.pth'))

    if args.log_model:
        torch.save(model, args.workdir.checkpoint_path)
    return model

# Script entry point: `main` is called without arguments; its `args`
# parameter is expected to be supplied by the `hydra.main` decorator.
if __name__=='__main__':
    # NOTE(review): the value that comes back through the hydra.main wrapper
    # may not be the trained model — confirm before relying on `model`.
    model = main()
