import argparse
import os
import sys
import numpy as np
import random

sys.path.append(os.getcwd())

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch_geometric.logging import log

import torch_geometric
import torch_geometric.transforms as T
from inductive_data_handler import get_dataloder
from torch_geometric.loader import DataLoader
import train_utils


# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
# Every helper in this file moves its tensors onto this device.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
import warnings
# Hide UserWarnings whose message starts with 'TypedStorage is deprecated'.
warnings.filterwarnings(
    action='ignore',
    message='TypedStorage is deprecated',
    category=UserWarning,
)


def get_data(batch, loader, has_edgewise_graph):
    """Move one loader batch onto the module-level ``device``.

    Args:
        batch: A single batch object, or a ``(data, edgewise_graph)`` pair
            when ``has_edgewise_graph`` is true.
        loader: Not used by this function; kept so existing call sites
            remain valid.
        has_edgewise_graph: Whether ``batch`` is a pair carrying a companion
            edgewise graph.

    Returns:
        A ``(data, edgewise_graph)`` tuple with both items on ``device``.
        The second item is ``None`` when the batch has no edgewise graph.
    """
    if not has_edgewise_graph:
        return batch.to(device), None

    graph, edge_graph = batch
    return graph.to(device), edge_graph.to(device)


def calculate_accuracy(output, target):
    """Return a per-sample boolean tensor marking correct predictions.

    Args:
        output: Model scores; the predicted class is the ``argmax`` over the
            last dimension.
        target: Ground-truth class ids. Any numeric dtype is accepted; the
            values are cast to ``torch.long`` before comparison.

    Returns:
        A boolean tensor, ``True`` where the predicted class equals the
        target. Callers concatenate these and average them to get accuracy.
    """
    # Cast directly on the device that holds the predictions.
    # `.type(torch.LongTensor)` would first copy a GPU tensor back to the
    # host, only for it to be moved to the device again.
    target = target.to(device=output.device, dtype=torch.long)
    preds = output.argmax(-1)
    return preds == target


@torch.no_grad()
def test(model, test_loader, loss_fn, args=None):
    """Evaluate ``model`` on ``test_loader`` with gradients disabled.

    Args:
        model: Network to evaluate; switched to eval mode here.
        test_loader: Iterable of batches accepted by ``get_data``.
        loss_fn: Callable applied to the flattened output and flattened
            labels.
        args: Run configuration. Despite the ``None`` default it is required:
            ``args.model.params.task_type``, ``args.model.edge_based`` and
            ``args.dataset.has_edgewise_graph`` are read.

    Returns:
        ``(test_loss, total_acc)``. ``test_loss`` is the mean of the
        per-batch losses as a tensor. ``total_acc`` is the fraction of
        correct predictions for ``'graph_classification'`` tasks and ``None``
        otherwise.
    """
    model.eval()

    # These settings do not change between batches, so read them once.
    is_classification = args.model.params.task_type in ['graph_classification']
    has_edgewise_graph = args.dataset.has_edgewise_graph
    edge_based = args.model.edge_based

    loss = []
    accs = []
    for batch in test_loader:
        data, edgewise_graph = get_data(batch, test_loader, has_edgewise_graph)
        # get_data has already moved the batch to `device`, so cast the labels
        # in place. `.type(torch.FloatTensor).to(device)` would copy GPU
        # labels to the host and back on every batch.
        data.y = data.y.float()

        if edge_based:
            if has_edgewise_graph:
                edgewise_edge_index = edgewise_graph.edge_index
            else:
                edgewise_edge_index = train_utils.get_edgewise_graph(
                    data.edge_index)
                edgewise_edge_index = edgewise_edge_index.to(device)

            if getattr(data, 'edge_attr', None) is not None:
                edge_feature_init = data.edge_attr.unsqueeze(dim=-1)
            else:
                edge_feature_init = train_utils.get_edge_initialization(
                    data, init_type="zeros")
            # NOTE: can't send edgewise_graph since some datasets may not
            # have this edgewise_edge_index
            out = model(
                x=edge_feature_init,
                data=data,
                edgewise_edge_index=edgewise_edge_index,
            )
        else:
            out = model(data)

        loss.append(float(loss_fn(out.flatten(), data.y.flatten())))
        if is_classification:
            accs.append(calculate_accuracy(out.detach(), data.y))

    test_loss = torch.mean(torch.tensor(loss))
    if is_classification:
        accs = torch.cat(accs)
        total_acc = accs.sum() / len(accs)
    else:
        total_acc = None

    return test_loss, total_acc

def train(model, train_loader, optimizer, loss_fn, 
          epoch, scheduler=None, use_wandb=False, args=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; switched to train mode here.
        train_loader: Iterable of batches accepted by ``get_data``.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable applied to the flattened output and flattened
            labels.
        epoch: Current epoch number. Not read inside this function.
        scheduler: Optional LR scheduler. It is stepped after every batch
            unless its class is named ``ReduceLROnPlateau``, which the caller
            is expected to step itself.
        use_wandb: When true, output statistics, the batch loss and the
            values from ``train_utils.log_gradients`` /
            ``train_utils.log_statistics`` are logged every 10th batch.
        args: Run configuration. Despite the ``None`` default it is required:
            ``args.model.params.task_type``, ``args.model.edge_based`` and
            ``args.dataset.has_edgewise_graph`` are read.

    Returns:
        ``(full_loss, total_acc)``. ``full_loss`` is the mean of the
        per-batch losses as a tensor. ``total_acc`` is the fraction of
        correct predictions for ``'graph_classification'`` tasks and ``None``
        otherwise.
    """
    model.train()

    # These settings do not change between batches, so read them once.
    is_classification = args.model.params.task_type in ['graph_classification']
    has_edgewise_graph = args.dataset.has_edgewise_graph
    edge_based = args.model.edge_based
    # Plateau schedulers need a metric, so they are stepped by the caller.
    step_scheduler_per_batch = (
        scheduler is not None
        and scheduler.__class__.__name__ not in ["ReduceLROnPlateau"])

    full_loss = []
    accs = []
    for i, batch in enumerate(train_loader):
        data, edgewise_graph = get_data(batch, train_loader, has_edgewise_graph)
        # get_data has already moved the batch to `device`, so cast the labels
        # in place. `.type(torch.FloatTensor).to(device)` would copy GPU
        # labels to the host and back on every batch.
        data.y = data.y.float()

        if edge_based:
            if has_edgewise_graph:
                edgewise_edge_index = edgewise_graph.edge_index
            else:
                edgewise_edge_index = train_utils.get_edgewise_graph(
                    data.edge_index)
                edgewise_edge_index = edgewise_edge_index.to(device)

            if getattr(data, 'edge_attr', None) is not None:
                edge_feature_init = data.edge_attr.unsqueeze(dim=-1)
            else:
                edge_feature_init = train_utils.get_edge_initialization(
                    data, init_type="zeros")

            # NOTE: can't send edgewise_graph since some datasets may not
            # have this edgewise_edge_index
            out = model(
                x=edge_feature_init,
                data=data,
                edgewise_edge_index=edgewise_edge_index,
            )
        else:
            out = model(data)

        optimizer.zero_grad()
        loss = loss_fn(out.flatten(), data.y.flatten())
        loss.backward()
        optimizer.step()

        if is_classification:
            accs.append(calculate_accuracy(out.detach(), data.y))

        if step_scheduler_per_batch:
            scheduler.step()

        batch_loss = loss.item()
        full_loss.append(batch_loss)

        if use_wandb and i % 10 == 0:
            # wandb.watch(model, log_freq=1, log='all')
            detached = out.detach()
            log_dict = {
                "data_stats/out_mean": detached.mean().item(),
                "data_stats/out_var": detached.var().item(),
                "data_stats/out_max": detached.max().item(),
                "data_stats/out_min": detached.min().item(),
                "mid_train_loss/loss": batch_loss,
            }
            log_dict.update(train_utils.log_gradients(model))
            log_dict.update(train_utils.log_statistics(model))
            wandb.log(log_dict)

    full_loss = torch.mean(torch.tensor(full_loss))

    if is_classification:
        accs = torch.cat(accs)
        total_acc = accs.sum() / len(accs)
    else:
        total_acc = None
    return full_loss, total_acc

def get_dataset(args):
    """Build the train and test ``DataLoader`` objects for a run.

    Args:
        args: Run configuration; ``args.dataset.batch_size`` and
            ``args.num_workers`` are read here, and ``args`` is forwarded
            unchanged to ``get_dataloder``.

    Returns:
        ``(train_loader, test_loader)``. Only the train loader shuffles.
    """
    # The two objects returned here are passed to DataLoader as its first
    # argument, so they are presumably datasets despite the helper's name.
    train_source, test_source = get_dataloder(args)

    # NOTE(review): drop_last=True is also applied to the test loader, so a
    # trailing incomplete batch is not evaluated - confirm this is intended.
    train_loader = DataLoader(
        train_source,
        shuffle=True,
        drop_last=True,
        num_workers=args.num_workers,
        batch_size=args.dataset.batch_size,
    )
    test_loader = DataLoader(
        test_source,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
        batch_size=args.dataset.batch_size,
    )
    return train_loader, test_loader



@hydra.main(config_path="./configs", config_name="config")
def main(args: DictConfig):
    """Train and evaluate a model as described by the Hydra config.

    The function seeds the RNGs, optionally starts a wandb run, builds the
    data loaders, loss, model, optimizer and scheduler, and then alternates
    ``train`` and ``test`` for ``args.epochs`` epochs. When
    ``args.log_model`` is set, the model is saved every 5 epochs and once
    more at the end. If the training loss becomes NaN the process exits
    after tagging the wandb run and saving the model, where those options
    are enabled.

    Args:
        args: Hydra/OmegaConf run configuration.

    Returns:
        The trained model.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    if args.use_wandb:
        wandb_config = OmegaConf.to_container(
            args, resolve=True, throw_on_missing=True)
        tags = args.wandb.tags
        # tags=None is wandb.init's default, so one call covers both the
        # tagged and the untagged case.
        run = wandb.init(
            project=args.wandb.project,
            group=args.wandb.group,
            name=args.wandb.name,
            config=wandb_config,
            entity=args.wandb.entity,
            tags=None if tags is None else [tags])

    if args.log_model:
        model_folder_path = os.path.join(
            args.workdir.root,
            args.workdir.name,)
        os.makedirs(model_folder_path, exist_ok=True)

    if args.num_threads > 0:
        torch.set_num_threads(args.num_threads)

    train_loader, test_loader = get_dataset(args)

    loss_fn = instantiate(args.dataset.loss_type)

    model = instantiate(
        args.model.params,
        in_channels=args.dataset.num_features,
        out_channels=args.dataset.num_classes,
        _recursive_=False
    )

    model = model.to(device)
    optimizer, scheduler = train_utils.initialize_optimizer(
        model, args, optimizer_type=args.optimizer)

    total_params = sum(p.numel() for p in model.parameters())
    total_grad_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total parameters = {total_params} require_grad {total_grad_params}')

    # train() steps every other scheduler per batch; a plateau scheduler is
    # stepped here because it needs the epoch's test loss.
    plateau_scheduler = (
        scheduler is not None
        and scheduler.__class__.__name__ in ["ReduceLROnPlateau"])

    for epoch in range(1, args.epochs + 1):
        train_loss, total_train_acc = train(
            model, train_loader,
            optimizer, loss_fn,
            epoch=epoch,
            scheduler=scheduler,
            use_wandb=args.use_wandb,
            args=args,)
        test_loss, total_test_acc = test(model, test_loader, loss_fn, args=args)
        if plateau_scheduler:
            scheduler.step(test_loss)

        has_acc = total_train_acc is not None

        if args.print:
            print(f"Epoch Number: {epoch}")
            console_metrics = dict(
                Epoch=epoch,
                train_loss=train_loss,
                test_loss=test_loss)
            if has_acc:
                console_metrics.update(
                    train_acc=total_train_acc,
                    test_acc=total_test_acc)
            log(**console_metrics)

        # exit the run if we get NaN values for an epoch.
        if np.isnan(train_loss):
            if args.use_wandb:
                run.tags = run.tags + ("nan_exit",)
            if args.log_model:
                torch.save(model, args.workdir.checkpoint_path)
            # sys.exit() rather than the bare exit(), which is a helper
            # injected by the `site` module and is not always available.
            sys.exit()

        if args.use_wandb:
            wandb_metrics = {
                'Loss/train': train_loss,
                'Loss/test': test_loss,
            }
            if has_acc:
                wandb_metrics['Acc/train'] = total_train_acc
                wandb_metrics['Acc/test'] = total_test_acc
            wandb_metrics['params/learning_rate'] = optimizer.param_groups[0]['lr']
            wandb_metrics['epoch'] = epoch
            wandb.log(wandb_metrics)

        if args.log_model and epoch % 5 == 0:
            ckpt_path = os.path.join(args.workdir.root, args.workdir.name)
            torch.save(
                model,
                os.path.join(ckpt_path, f'checkpoint_{epoch:04}.pth'))
            # save almost everything
            torch.save(model, args.workdir.checkpoint_path)

    if args.log_model:
        torch.save(model, args.workdir.checkpoint_path)
    return model

# Script entry point. main is wrapped by @hydra.main, so it is called here
# without arguments; whatever the wrapped call returns is bound to `model`.
if __name__=='__main__':
    model = main()

