import argparse
import os
import sys
import numpy as np
import random

sys.path.append(os.getcwd())

import hydra
from omegaconf import DictConfig, OmegaConf

import torch
import torch.nn.functional as F
import wandb
from torch_geometric.logging import log

import torch_geometric.transforms as T
import torch_geometric.utils as geo_utils
from dataset_handler import *
from get_models import get_model
from train_utils import *

# Run on the first CUDA device when one is visible, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


@torch.no_grad()
def test(model, data, loss_fn, edgewise_edge_index=None):
    """Evaluate ``model`` on the train, validation and test masks of ``data``.

    Args:
        model: model to evaluate (put into ``eval`` mode here).
        data: graph object exposing ``x``, ``edge_index``, ``y``, ``batch``
            and ``train_mask`` / ``val_mask`` / ``test_mask``.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar.
        edgewise_edge_index: optional line-graph connectivity, forwarded to
            edge-based models only.

    Returns:
        ``(accs, losses)``: two lists ordered ``[train, val, test]``.
        Accuracy is the fraction of masked nodes whose argmax prediction
        equals ``y``; an empty mask yields an accuracy of ``0.0`` instead of
        raising ``ZeroDivisionError``.
    """
    model.eval()
    # nn.DataParallel hides the real class behind ``.module``; dispatch on it.
    inner = model.module if isinstance(model, torch.nn.DataParallel) else model
    if inner.__class__.__name__ in [
            'BeliefPropLayers', 'GraphEdgeNetwork', 'EdgeGCN']:
        edge_features_init = get_edge_initialization(data)
        out = model(
            x=edge_features_init,
            edge_index=data.edge_index,
            node_feature=data.x,
            edgewise_edge_index=edgewise_edge_index,
        )
    else:
        # Prepare inputs exactly as ``train`` does so evaluation sees the same
        # dtype and batch vector the model was trained on.
        out = model(data.x.float(), data.edge_index, batch=data.batch)
    pred = out.argmax(dim=-1)
    accs = []
    losses = []
    for mask in (data.train_mask, data.val_mask, data.test_mask):
        num_nodes = int(mask.sum())
        correct = int((pred[mask] == data.y[mask]).sum())
        accs.append(correct / num_nodes if num_nodes else 0.0)
        losses.append(float(loss_fn(out[mask], data.y[mask])))
    return accs, losses


# A module-level function literally named ``test`` is collected by pytest as a
# test case whenever this module is imported into a test file; opt it out.
test.__test__ = False

def train(model, data, optimizer, loss_fn,
          edgewise_edge_index=None,
          scheduler=None, use_wandb=False):
    """Run one full-batch optimisation step on the training mask.

    Args:
        model: model to train (put into ``train`` mode here).
        data: graph object exposing ``x``, ``edge_index``, ``y``, ``batch``
            and ``train_mask``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar
            tensor.
        edgewise_edge_index: optional line-graph connectivity, forwarded to
            edge-based models only.
        scheduler: optional LR scheduler, stepped once per call.
        use_wandb: when true, log summary statistics of the model output.

    Returns:
        The training loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    # nn.DataParallel hides the real class behind ``.module``; dispatch on it.
    inner = model.module if isinstance(model, torch.nn.DataParallel) else model
    if inner.__class__.__name__ in [
            'BeliefPropLayers', 'GraphEdgeNetwork', 'EdgeGCN']:
        edge_features_init = get_edge_initialization(data)
        out = model(
            x=edge_features_init,
            edge_index=data.edge_index,
            node_feature=data.x,
            edgewise_edge_index=edgewise_edge_index,
        )
    else:
        # ``.float()`` casts on the tensor's current device, avoiding the
        # device -> host -> device copy of ``.type(torch.FloatTensor)``.
        out = model(data.x.float(), data.edge_index, batch=data.batch)
    loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
    if use_wandb:
        out_stats = out.detach()
        wandb.log({
            "data_stats/out_mean": out_stats.mean().item(),
            "data_stats/out_var": out_stats.var().item(),
            "data_stats/out_max": out_stats.max().item(),
            "data_stats/out_min": out_stats.min().item(),
        })
    return float(loss)

def initialize_optimizer(model, args, optimizer_type='adam'):
    """Build the optimizer and optional LR scheduler from the run config.

    Args:
        model: module whose parameters are optimised.
        args: config exposing ``model.lr``, ``model.weight_decay``,
            ``lr_schedule`` and ``epochs``.
        optimizer_type: ``'adam'`` or ``'sgd'``.

    Returns:
        ``(optimizer, scheduler)``; ``scheduler`` is ``None`` for the
        ``'constant'`` schedule and a ``CosineAnnealingLR`` spanning
        ``args.epochs`` steps for ``'cosine_lr'``.

    Raises:
        ValueError: if ``optimizer_type`` or ``args.lr_schedule`` is unknown.
    """
    optimizer_classes = {
        'adam': torch.optim.Adam,
        # Previously this key built Adam as well, silently ignoring 'sgd'.
        'sgd': torch.optim.SGD,
    }
    if optimizer_type not in optimizer_classes:
        raise ValueError(f"Invalid optimizer option {optimizer_type}")
    optimizer = optimizer_classes[optimizer_type](
        model.parameters(),
        lr=args.model.lr,
        weight_decay=args.model.weight_decay)

    if args.lr_schedule == 'cosine_lr':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs)
    elif args.lr_schedule == 'constant':
        scheduler = None
    else:
        # Must be raised; constructing the exception alone led to an
        # UnboundLocalError on ``scheduler`` below.
        raise ValueError(f"Invalid Lr schedule {args.lr_schedule}")
    return optimizer, scheduler

def log_gradients(model, log_fn=None):
    """Log the L2 norm of every parameter gradient plus their global norm.

    Parameters whose ``.grad`` is ``None`` (unused in the last forward pass,
    or frozen) are skipped instead of raising ``AttributeError``.

    Args:
        model: module whose ``named_parameters()`` are inspected, normally
            right after ``loss.backward()``.
        log_fn: callable receiving the metrics dict; defaults to
            ``wandb.log``.

    Returns:
        The dict that was logged: ``gradients/<name>`` per parameter and
        ``Loss/total_gradient`` for the global norm.
    """
    grad_dict = {}
    total_sq_norm = 0.0
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        param_norm = p.grad.detach().norm(2).item()
        grad_dict[f"gradients/{name}"] = param_norm
        total_sq_norm += param_norm ** 2
    grad_dict["Loss/total_gradient"] = total_sq_norm ** 0.5
    (wandb.log if log_fn is None else log_fn)(grad_dict)
    return grad_dict

def log_statistics(model, log_fn=None):
    """Log the max, min and mean value of every parameter tensor.

    Args:
        model: module whose ``named_parameters()`` are summarised.
        log_fn: callable receiving the metrics dict; defaults to
            ``wandb.log``.

    Returns:
        The dict that was logged, keyed ``model_max/<name>``,
        ``model_min/<name>`` and ``model_mean/<name>``.
    """
    model_dict = {}
    for name, p in model.named_parameters():
        # Detach first so the reductions are not recorded by autograd.
        values = p.detach()
        model_dict[f"model_max/{name}"] = values.max().item()
        model_dict[f"model_min/{name}"] = values.min().item()
        model_dict[f"model_mean/{name}"] = values.mean().item()
    (wandb.log if log_fn is None else log_fn)(model_dict)
    return model_dict

@hydra.main(config_path="./configs", config_name="config")
def main(args: DictConfig):
    """Train a graph model end to end from the Hydra config.

    Seeds the torch / numpy / random RNGs, optionally starts a wandb run,
    loads the dataset, builds the model, optimizer and scheduler, trains
    for ``args.epochs`` full-batch epochs while printing/logging metrics,
    and optionally writes checkpoints under ``args.workdir``.

    Args:
        args: Hydra config loaded from ``./configs`` (``config``) plus any
            command-line overrides.

    Returns:
        The trained model (wrapped in ``nn.DataParallel`` when more than one
        GPU is visible).

    Raises:
        ValueError: if ``args.dataset.loss_type`` is not
            ``'cross_entropy_loss'`` or ``'l2_loss'``.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    if args.use_wandb:
        wandb_config = OmegaConf.to_container(
            args, resolve=True, throw_on_missing=True)
        # wandb.init treats tags=None as "no tags", so one call covers both cases.
        tags = list(args.wandb.tags) if args.wandb.tags is not None else None
        wandb.init(
            project=args.wandb.project,
            group=args.wandb.group,
            name=args.wandb.name,
            config=wandb_config,
            tags=tags)
    if args.log_model:
        ckpt_dir = os.path.join(args.workdir.root, args.workdir.name)
        os.makedirs(ckpt_dir, exist_ok=True)

    _, data = get_dataset(args)
    data = data.to(device)

    loss_type = args.dataset.loss_type
    is_regression = loss_type == 'l2_loss'
    if loss_type == 'cross_entropy_loss':
        loss_fn = torch.nn.CrossEntropyLoss()
        num_features = data.x.shape[-1]
        num_classes = len(data.y.unique())

        homophily_level = geo_utils.homophily(data.edge_index, data.y)
        print(f'Homophily level of the dataset is {homophily_level}')
        if args.use_wandb:
            wandb.log({"Homophily Level": homophily_level})
    elif is_regression:
        loss_fn = torch.nn.MSELoss()
        num_features = 1
        num_classes = 1
    else:
        raise ValueError(f"Invalid Loss Type {loss_type}")

    model = get_model(
        args.model,
        num_features=num_features,
        num_classes=num_classes)
    # Capture the class name before any DataParallel wrapping, which would
    # otherwise report 'DataParallel' and hide the edge-model check below.
    model_name = model.__class__.__name__
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # NOTE(review): DataParallel scatters every tensor input along dim 0,
        # including edge_index ([2, E]); confirm this is valid for these
        # full-graph models before relying on multi-GPU runs.
        model = torch.nn.DataParallel(model)

    model = model.to(device)
    optimizer, scheduler = initialize_optimizer(
        model, args, optimizer_type=args.optimizer)

    # we are not going to use the edge_attr for a while.
    if model_name in ['GraphEdgeNetwork', 'EdgeGCN']:
        edgewise_edge_index = get_edgewise_graph(data.edge_index)
    else:
        edgewise_edge_index = None

    total_params = sum(p.numel() for p in model.parameters())
    total_grad_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total parameters = {total_params} require_grad {total_grad_params}')

    best_train_acc = 0
    best_val_acc = 0
    best_test_acc = 0
    for epoch in range(1, args.epochs + 1):
        # Keep the optimisation-step loss under its own name: the evaluation
        # losses returned by ``test`` used to overwrite it.
        step_loss = train(model, data, optimizer, loss_fn,
                          edgewise_edge_index=edgewise_edge_index,
                          scheduler=scheduler, use_wandb=args.use_wandb)
        accs, losses = test(model, data, loss_fn,
                            edgewise_edge_index=edgewise_edge_index)
        train_acc, val_acc, test_acc = accs
        train_loss, val_loss, test_loss = losses
        best_train_acc = get_best_val(best_train_acc, train_acc)
        best_val_acc = get_best_val(best_val_acc, val_acc)
        best_test_acc = get_best_val(best_test_acc, test_acc)

        # Branch on the configured loss type; comparing the loss module itself
        # to a string was always False.
        if args.print:
            if is_regression:
                log(Epoch=epoch,
                    Loss=step_loss,
                    Train=train_acc,
                    Val=val_acc,
                    Test=test_acc)
            else:
                log(Epoch=epoch,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    test_loss=test_loss,
                    train_acc=train_acc,
                    val_acc=val_acc,
                    test_acc=test_acc)

        if args.use_wandb:
            metrics = {
                'Loss/train': train_loss,
                'Loss/val': val_loss,
                'Loss/test': test_loss,
                'epoch': epoch,
            }
            if not is_regression:
                metrics['Loss/test_acc'] = test_acc
                metrics['Loss/train_acc'] = train_acc
            wandb.log(metrics)
            log_gradients(model)
            log_statistics(model)

        if args.log_model and epoch % 100 == 0:
            torch.save(
                model,
                os.path.join(ckpt_dir, f'checkpoint_{epoch:04}.pth'))

    print(f'Best accuracies: train={best_train_acc} '
          f'val={best_val_acc} test={best_test_acc}')
    if args.log_model:
        torch.save(model, args.workdir.checkpoint_path)
    return model

if __name__ == '__main__':
    # ``main`` is decorated with ``hydra.main``, which supplies the config.
    model = main()
