import argparse
import os
import sys
import numpy as np
import random

sys.path.append(os.getcwd())

import hydra
from omegaconf import DictConfig, OmegaConf

import torch
import torch.nn.functional as F
import wandb
from torch_geometric.logging import log

import torch_geometric.transforms as T
import torch_geometric.utils as geo_utils
from dataset_handler import *
from get_models import get_model
from folder_dataloader import FolderDataLoader
from torch_geometric.loader import DataLoader

from train_utils import *

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# train() and test() below read this module-level value.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import warnings
# Silence UserWarnings whose message starts with 'TypedStorage is deprecated'.
warnings.filterwarnings(
    'ignore', 
    category=UserWarning, 
    message='TypedStorage is deprecated')

@torch.no_grad()
def test(model, test_loader, loss_fn):
    """Evaluate ``model`` on the test-masked entries of every batch.

    The call signature used for the forward pass is chosen from the model's
    class name: the edge-feature models receive an edge-feature
    initialization as ``x`` plus the node features, every other model is
    called with node features and ``edge_index`` only.

    Args:
        model: Network to evaluate; switched to eval mode here.
        test_loader: Iterable of batches exposing ``x``, ``y``,
            ``edge_index``, ``batch`` and ``test_mask`` (and
            ``edgewise_edge_index`` for the models that consume it).
        loss_fn: Callable ``loss_fn(predictions, targets)``.

    Returns:
        Tuple ``(test_loss, test_accs)``: the mean of the per-batch losses
        and the fraction of masked entries predicted correctly (tensors).
    """
    model.eval()
    # The model class does not change between batches, so resolve the
    # dispatch flags once instead of on every iteration.
    model_name = model.__class__.__name__
    uses_edge_features = model_name in (
        'BeliefPropLayers',
        'BeliefPropLayersImprovedMessagePassing',
        'GraphEdgeNetwork',
        'GraphEdgeConvNetwork',
    )
    # Only plain BeliefPropLayers runs without the edgewise graph.
    needs_edgewise_graph = model_name != 'BeliefPropLayers'
    # GraphEdgeConvNetwork is the only model not initialized with zeros.
    init_w_zeros = model_name != 'GraphEdgeConvNetwork'

    losses = []
    accs = []
    for data in test_loader:
        data = data.to(device)
        mask = data.test_mask
        # ``data`` already lives on ``device``; ``.float()`` casts in place
        # on that device instead of bouncing through a CPU FloatTensor.
        node_feature = data.x.float()
        if uses_edge_features:
            if needs_edgewise_graph:
                edgewise_edge_index = data.edgewise_edge_index
            else:
                edgewise_edge_index = None
            edge_feature_init = get_edge_initialization(
                data, init_w_zeros=init_w_zeros)
            out = model(
                x=edge_feature_init,
                edge_index=data.edge_index,
                node_feature=node_feature,
                edgewise_edge_index=edgewise_edge_index,
                batch=data.batch,
            )
        else:
            out = model(node_feature,
                        data.edge_index,
                        batch=data.batch)
        losses.append(float(loss_fn(out[mask], data.y[mask])))
        pred = out.argmax(dim=-1)
        accs.append(pred[mask] == data.y[mask])
    test_loss = torch.mean(torch.tensor(losses))
    full_accs = torch.hstack(accs)
    test_accs = full_accs.sum()/len(full_accs)
    return test_loss, test_accs

def train(model, train_loader, optimizer, loss_fn, 
          epoch, scheduler=None, use_wandb=False):
    """Run one training epoch over ``train_loader``.

    Args:
        model: Network to optimize; switched to train mode here.
        train_loader: Iterable of batches exposing ``x``, ``y``,
            ``edge_index``, ``batch`` and ``train_mask`` (and
            ``edgewise_edge_index`` for the models that consume it).
        optimizer: Optimizer updating ``model``'s parameters.
        loss_fn: Callable ``loss_fn(predictions, targets)``.
        epoch: Current epoch number (accepted for the caller's convenience;
            not used inside this function).
        scheduler: Optional LR scheduler, stepped after every batch.
        use_wandb: When true, log output/gradient/parameter statistics to
            wandb every 10th batch.

    Returns:
        Tuple ``(full_loss, full_accs)``: the mean of the per-batch losses
        and the fraction of masked entries predicted correctly (tensors).
    """
    model.train()
    full_loss = []
    full_accs = []
    # The model class does not change between batches; resolve it once.
    model_name = model.__class__.__name__
    # NOTE(review): test() also routes 'GraphEdgeConvNetwork' through the
    # edge-feature call signature, while here it falls through to the plain
    # node-feature call. Confirm which of the two is intended.
    uses_edge_features = model_name in (
        'BeliefPropLayers',
        'BeliefPropLayersImprovedMessagePassing',
        'GraphEdgeNetwork',
    )
    # Only plain BeliefPropLayers runs without the edgewise graph.
    needs_edgewise_graph = model_name != 'BeliefPropLayers'

    for i, data in enumerate(train_loader):
        data = data.to(device)
        mask = data.train_mask
        # Clear the gradients left by the previous batch; without this every
        # backward() call adds onto the gradients of all earlier batches.
        optimizer.zero_grad()
        # ``data`` already lives on ``device``; ``.float()`` casts in place
        # on that device instead of bouncing through a CPU FloatTensor.
        node_feature = data.x.float()
        if uses_edge_features:
            if needs_edgewise_graph:
                edgewise_edge_index = data.edgewise_edge_index
            else:
                edgewise_edge_index = None
            edge_feature_init = get_edge_initialization(
                data, init_w_zeros=True)
            out = model(
                x=edge_feature_init,
                edge_index=data.edge_index,
                node_feature=node_feature,
                edgewise_edge_index=edgewise_edge_index,
                batch=data.batch,
            )
        else:
            out = model(node_feature,
                        data.edge_index,
                        batch=data.batch)
        loss = loss_fn(out[mask], data.y[mask])
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            # NOTE(review): initialize_optimizer builds CosineAnnealingLR
            # with T_max=args.epochs, but the scheduler is stepped once per
            # batch here. Confirm whether per-epoch stepping was intended.
            scheduler.step()
        full_loss.append(loss.item())
        pred = out.argmax(dim=-1)
        full_accs.append(pred[mask] == data.y[mask])
        if use_wandb and i % 10 == 0:
            detached_out = out.detach()
            log_dict = {
                "data_stats/out_mean": detached_out.mean().item(),
                "data_stats/out_var": detached_out.var().item(),
                "data_stats/out_max": detached_out.max().item(),
                "data_stats/out_min": detached_out.min().item(),
                "mid_train_loss/loss": loss.item(),
            }
            log_dict.update(log_gradients(model))
            log_dict.update(log_statistics(model))
            wandb.log(log_dict)

    full_loss = torch.mean(torch.tensor(full_loss))
    full_accs = torch.hstack(full_accs)
    full_accs = full_accs.sum()/len(full_accs)
    return full_loss, full_accs

def get_dataset(args):
    """Build the train and test loaders for the configured dataset.

    Args:
        args: Config object. Reads ``args.dataset.name``,
            ``args.dataset.sub_name`` (for ``'WikipediaNetwork'``),
            ``args.dataset.batch_size`` and ``args.num_workers``.

    Returns:
        Tuple ``(train_loader, test_loader, num_features, num_classes)``;
        the two counts are read from the training ``FolderDataLoader``.

    Raises:
        ValueError: If ``args.dataset.name`` is not a supported dataset.
    """
    if args.dataset.name in ['cora', 'cora_2']:
        data_name = 'cora'
    elif args.dataset.name == 'WikipediaNetwork':
        data_name = args.dataset.sub_name
    else:
        # Fail here with a clear message instead of hitting an unbound
        # ``data_name`` a few lines further down.
        raise ValueError(f"{args.dataset.name} is not considered")

    train_dataloader = FolderDataLoader(
        data_name=data_name,
        mode='train',
    )
    num_classes = train_dataloader.num_classes
    num_features = train_dataloader.num_features
    test_dataloader = FolderDataLoader(
        data_name=data_name,
        mode='test',
    )
    train_loader = DataLoader(
        train_dataloader,
        batch_size=args.dataset.batch_size,
        num_workers=args.num_workers,
        drop_last=True,
        shuffle=True)
    # NOTE(review): drop_last=True also discards the final partial batch of
    # the test split, so evaluation may not cover every sample -- confirm
    # this is intended.
    test_loader = DataLoader(
        test_dataloader,
        batch_size=args.dataset.batch_size,
        num_workers=args.num_workers,
        drop_last=True,
        shuffle=False)
    return train_loader, test_loader, num_features, num_classes

def initialize_optimizer(model, args, optimizer_type='adam'):
    """Create the optimizer and optional LR scheduler for ``model``.

    Args:
        model: Module whose ``parameters()`` are optimized.
        args: Config object. Reads ``args.model.lr``,
            ``args.model.weight_decay``, ``args.lr_schedule`` and, for the
            cosine schedule, ``args.epochs``.
        optimizer_type: ``'adam'`` or ``'sgd'``.

    Returns:
        Tuple ``(optimizer, scheduler)``; ``scheduler`` is ``None`` for the
        ``'constant'`` schedule.

    Raises:
        ValueError: If ``optimizer_type`` or ``args.lr_schedule`` is not one
            of the supported options.
    """
    if optimizer_type == 'adam':
        optimizer = torch.optim.Adam(
                model.parameters(),
                lr=args.model.lr,
                weight_decay=args.model.weight_decay)
    elif optimizer_type == 'sgd':
        # This branch previously built Adam as well, so asking for SGD
        # silently trained with Adam.
        optimizer = torch.optim.SGD(
                model.parameters(),
                lr=args.model.lr,
                weight_decay=args.model.weight_decay)
    else:
        raise ValueError(f"Invalid optimizer option {optimizer_type}")

    if args.lr_schedule == 'cosine_lr':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs)
    elif args.lr_schedule == 'constant':
        scheduler = None
    else:
        # Raise instead of falling through to an unbound ``scheduler``.
        raise ValueError("Invalid Lr schedule")
    return optimizer, scheduler


def calculate_total_model_gradient_norm(model):
    """Return the global L2 norm of all gradients currently stored on ``model``.

    Parameters whose ``grad`` is ``None`` are skipped. The result is the
    square root of the summed squared per-parameter gradient norms, so a
    model without any gradients yields ``0.0``.
    """
    squared_norms = (
        param.grad.detach().data.norm(2).item() ** 2
        for _, param in model.named_parameters()
        if param.grad is not None
    )
    return sum(squared_norms) ** (1/2)

def log_gradients(model):
    """Collect gradient norms of ``model`` in a flat dict suitable for logging.

    Returns:
        Dict with one ``"gradients/<param name>"`` entry per parameter that
        has a gradient (its L2 norm), followed by ``"Loss/total_gradient"``,
        the square root of the summed squared per-parameter norms.
    """
    grad_dict = {
        f"gradients/{param_name}": param.grad.detach().data.norm(2).item()
        for param_name, param in model.named_parameters()
        if param.grad is not None
    }
    # Computed before the total is inserted, so only per-parameter norms
    # contribute to the sum.
    squared_total = sum(norm ** 2 for norm in grad_dict.values())
    grad_dict["Loss/total_gradient"] = squared_total ** (1/2)
    return grad_dict

def log_statistics(model):
    """Summarize every parameter of ``model`` by its max, min and mean value.

    Returns:
        Dict with ``"model_max/<name>"``, ``"model_min/<name>"`` and
        ``"model_mean/<name>"`` entries (Python floats) for each named
        parameter, in that order.
    """
    reducers = (
        ("model_max", torch.max),
        ("model_min", torch.min),
        ("model_mean", torch.mean),
    )
    model_dict = {}
    for param_name, param in model.named_parameters():
        for prefix, reduce_fn in reducers:
            model_dict[f"{prefix}/{param_name}"] = (
                reduce_fn(param).detach().item())
    return model_dict




@hydra.main(config_path="./configs", config_name="config")
def main(args: DictConfig):
    """Train and evaluate a model according to the Hydra config ``args``.

    Seeds the torch / numpy / python RNGs, optionally starts a wandb run,
    builds the data loaders, loss, model and optimizer, then alternates
    ``train`` and ``test`` for ``args.epochs`` epochs. When
    ``args.log_model`` is set, the model is saved every 5 epochs and once
    more after the final epoch.

    Args:
        args: Hydra/OmegaConf configuration.

    Returns:
        The trained model. If the training loss becomes NaN the process
        exits early instead (after tagging the wandb run and saving the
        model, when those options are enabled).

    Raises:
        ValueError: If ``args.dataset.loss_type`` is not supported.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    if args.use_wandb:
        wandb_config = OmegaConf.to_container(
            args, resolve=True, throw_on_missing=True)
        init_kwargs = dict(
            project=args.wandb.project,
            group=args.wandb.group,
            name=args.wandb.name,
            config=wandb_config)
        # Only forward tags when the config actually provides them.
        if args.wandb.tags is not None:
            init_kwargs['tags'] = list(args.wandb.tags)
        run = wandb.init(**init_kwargs)

    if args.log_model:
        model_folder_path = os.path.join(
            args.workdir.root,
            args.workdir.name,)
        os.makedirs(model_folder_path, exist_ok=True)

    if args.num_threads > 0:
        torch.set_num_threads(args.num_threads)

    train_loader, test_loader, num_features, num_classes = get_dataset(args)
    # For now this is just going to be used with ising models so we are fine.
    # ``torch.nn`` is spelled out because this file never imports ``nn``
    # itself; the bare name only resolved if a star import supplied it.
    if args.dataset.loss_type == 'cross_entropy_loss':
        loss_fn = torch.nn.CrossEntropyLoss()
    elif args.dataset.loss_type == 'l2_loss':
        loss_fn = torch.nn.MSELoss()
        num_features = 1
        num_classes = 1
    else:
        # Raise instead of continuing with an unbound ``loss_fn``.
        raise ValueError("Invalid Loss Type")

    model = get_model(
        args.model,
        num_features=num_features,
        num_classes=num_classes,)
    model = model.to(device)
    optimizer, scheduler = initialize_optimizer(
        model, args, optimizer_type=args.optimizer)

    total_params = sum(p.numel() for p in model.parameters())
    total_grad_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total parameters = {total_params} require_grad {total_grad_params}')

    for epoch in range(1, args.epochs + 1):
        train_loss, train_accs = train(model, train_loader,
                           optimizer, loss_fn,
                           epoch=epoch,
                           scheduler=scheduler,
                           use_wandb=args.use_wandb)
        test_loss, test_accs = test(model, test_loader, loss_fn)
        print(epoch)
        if args.print:
            log(Epoch=epoch,
                train_loss=train_loss,
                train_accs=train_accs,
                test_loss=test_loss,
                test_accs=test_accs,)
        # exit the run if we get NaN values for an epoch.
        if np.isnan(train_loss):
            if args.use_wandb:
                run.tags = run.tags + ("nan_exit",)
            if args.log_model:
                torch.save(model, args.workdir.checkpoint_path)
            # sys.exit is always available; the bare ``exit`` helper is
            # injected by the ``site`` module and is meant for the REPL.
            sys.exit()
        if args.use_wandb:
            wandb.log({
                'Loss/train': train_loss,
                'Loss/test': test_loss,
                'Loss/test_acc':test_accs,
                'Loss/train_acc':train_accs,
                'epoch': epoch,
            })
        if args.log_model and epoch % 5 == 0:
            # Keep a numbered snapshot and refresh the rolling checkpoint.
            torch.save(
                model,
                os.path.join(model_folder_path, f'checkpoint_{epoch:04}.pth'))
            torch.save(model, args.workdir.checkpoint_path)
    if args.log_model:
        torch.save(model, args.workdir.checkpoint_path)
    return model

# Script entry point: ``main`` is called without arguments because the
# hydra.main decorator supplies ``args`` from the ./configs/config file.
# NOTE(review): confirm whether the hydra.main wrapper forwards main's return
# value; if it does not, ``model`` will be None here.
if __name__=='__main__':
    model = main()

