import argparse
import os
import sys
import numpy as np
import random

sys.path.append(os.getcwd())

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

import torch
import torch.nn.functional as F
import wandb
from torch_geometric.logging import log

import torch_geometric.transforms as T
import torch_geometric.utils as geo_utils
from dataset_handler import *
# from get_models import get_model
from train_utils import *
from utils import *

@torch.no_grad()
def test(model, data, loss_fn, edgewise_edge_index=None, args=None):
    model.eval()

    if args.model.edge_based: 
        assert edgewise_edge_index is not None, (
            f"{args.model.model} requires edgewise edge index")
        edge_features_init = get_edge_initialization(
            data, init_type='data_edge')
        out = model(
            x=edge_features_init,
            data=data,
            edgewise_edge_index=edgewise_edge_index,
        )
    else:
        out = model(data)
    pred = out.argmax(dim=-1)
    accs = []
    loss = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
        loss.append(float(loss_fn(out[mask], data.y[mask])))

    if args.dataset.name in ['flipflop', 'd_regular_tree']:
        kl_loss_fn = torch.nn.KLDivLoss(reduction='batchmean')
        logsoftmax_out = F.log_softmax(out, dim=1)[data.test_mask]
        # dm = torch.tensor(data.marginals[data.test_mask]).type(Tensor).to(device)
        data.marginals = torch.tensor(data.marginals).type(Tensor).to(device)
        # kl_loss = float(kl_loss_fn(logsoftmax_out[data.test_mask], 
        #                            data.marginals[data.test_mask]))
        true_marginals = data.marginals[data.test_mask]
        kl_loss = true_marginals * (torch.log(true_marginals) - logsoftmax_out)
        kl_loss = kl_loss.sum(dim=1).mean()
        return accs, loss, kl_loss
    return accs, loss, 0

def train(model, data, optimizer, loss_fn,
          edgewise_edge_index=None,
          scheduler=None, use_wandb=False, args=None):
    """Perform a single full-batch optimisation step on the training mask.

    Args:
        model: network being trained; it is switched to train mode.
        data: graph object exposing ``y`` and ``train_mask``.
        optimizer: optimizer whose gradients are zeroed and then stepped.
        loss_fn: criterion applied to the training-mask outputs and labels.
        edgewise_edge_index: edge index of the edge-wise graph; required when
            ``args.model.edge_based`` is set.
        scheduler: optional LR scheduler. Every scheduler except
            ``ReduceLROnPlateau`` is stepped here; that one is left to the
            caller, which has the validation loss.
        use_wandb: when true, log summary statistics of the raw output.
        args: run config; ``args.model.edge_based`` and ``args.model.model``
            are read, so it is required despite the ``None`` default.

    Returns:
        The training loss tensor, converted with ``.type(Tensor)``.
    """
    model.train()
    optimizer.zero_grad()

    if not args.model.edge_based:
        out = model(data)
    else:
        assert edgewise_edge_index is not None, (
            f"{args.model.model} requires edgewise edge index")
        out = model(
            x=get_edge_initialization(data, init_type='data_edge'),
            data=data,
            edgewise_edge_index=edgewise_edge_index,
        )

    train_mask = data.train_mask
    loss = loss_fn(out[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()

    # Plateau schedulers need a metric, so they are not stepped here.
    if scheduler is not None and type(scheduler).__name__ != "ReduceLROnPlateau":
        scheduler.step()

    if use_wandb:
        snapshot = out.detach()
        wandb.log({
            "data_stats/out_mean": snapshot.mean().item(),
            "data_stats/out_var": snapshot.var().item(),
            "data_stats/out_max": snapshot.max().item(),
            "data_stats/out_min": snapshot.min().item(),
        })
    return loss.type(Tensor)

def _init_wandb(args):
    """Start a wandb run from the ``args.wandb`` settings and return it.

    The fully resolved Hydra config is attached as the run config; ``tags`` is
    converted to a list when configured and left as ``None`` otherwise.
    """
    wandb_config = OmegaConf.to_container(
        args, resolve=True, throw_on_missing=True)
    tags = list(args.wandb.tags) if args.wandb.tags is not None else None
    return wandb.init(
        project=args.wandb.project,
        group=args.wandb.group,
        name=args.wandb.name,
        entity=args.wandb.entity,
        config=wandb_config,
        tags=tags)


def _loss_target(loss_cfg):
    """Return the dotted import path that the loss config instantiates.

    ``instantiate`` takes a config node, so the class path lives under the
    node's ``_target_`` key; anything that is not a dict node is returned as-is.
    """
    if OmegaConf.is_dict(loss_cfg):
        return loss_cfg.get('_target_')
    return loss_cfg


@hydra.main(config_path="./configs", config_name="config")
def main(args: DictConfig):
    """Train a transductive node-classification model from a Hydra config.

    Seeds the torch/numpy/random RNGs, optionally starts a wandb run, loads the
    dataset, instantiates the loss, model and optimizer from the config, then
    runs ``args.epochs`` train/evaluate epochs with optional logging and
    checkpointing (every 100 epochs plus a final save).

    Returns:
        The trained model.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.set_num_threads(4)

    if args.use_wandb:
        _init_wandb(args)

    model_folder_path = None
    if args.log_model:
        model_folder_path = os.path.join(args.workdir.root, args.workdir.name)
        os.makedirs(model_folder_path, exist_ok=True)

    data = get_dataset(args).to(device)

    loss_fn = instantiate(args.dataset.loss_type)
    # loss_type is a config node (instantiate() does not accept a bare string),
    # so compare its _target_; comparing the node itself to a string is never
    # true and would make the MSE branches below unreachable.
    is_mse = _loss_target(args.dataset.loss_type) == 'torch.nn.MSELoss'

    if hasattr(args.dataset, 'num_features'):
        # Sizes come from the config; fall back to the labels for the class
        # count if only the feature size is configured.
        num_features = args.dataset.num_features
        num_classes = (args.dataset.num_classes
                       if hasattr(args.dataset, 'num_classes')
                       else len(data.y.unique()))
    else:
        num_features = data.x.shape[-1]
        num_classes = len(data.y.unique())

    if args.dataset.log_homophily:
        homophily_level = geo_utils.homophily(data.edge_index, data.y)
        print(f'Homophily level of the dataset is {homophily_level}')
        if args.use_wandb:
            wandb.log({
                "Homophily Level": homophily_level,
            })

    model = instantiate(
        args.model.params,
        in_channels=num_features,
        out_channels=num_classes,
        _recursive_=False)

    model = model.to(device)
    optimizer, scheduler = initialize_optimizer(
        model, args, optimizer_type=args.optimizer, transductive=True)

    # edge_attr is currently unused. Use the precomputed edge-wise graph when
    # the dataset provides one, otherwise derive it from data.edge_index.
    if args.dataset.has_edgewise_graph:
        edgewise_edge_index = data.edgewise_edge_index.to(device)
    else:
        edgewise_edge_index = get_edgewise_edge_index(data.edge_index).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    total_grad_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total parameters = {total_params} require_grad {total_grad_params}')

    # NOTE(review): the best_* values are tracked but never reported.
    best_train_acc = 0
    best_val_acc = 0
    best_test_acc = 0
    for epoch in range(1, args.epochs + 1):
        train(model, data, optimizer, loss_fn,
              edgewise_edge_index=edgewise_edge_index,
              scheduler=scheduler, use_wandb=args.use_wandb,
              args=args)
        acc, loss, kl_loss = test(model, data, loss_fn,
                                  edgewise_edge_index=edgewise_edge_index,
                                  args=args)
        train_acc, val_acc, test_acc = acc
        train_loss, val_loss, test_loss = loss

        # ReduceLROnPlateau is stepped on the validation loss here; every
        # other scheduler is stepped inside train().
        if (scheduler is not None
                and scheduler.__class__.__name__ == "ReduceLROnPlateau"):
            scheduler.step(val_loss)

        best_train_acc = get_best_val(best_train_acc, train_acc)
        best_val_acc = get_best_val(best_val_acc, val_acc)
        best_test_acc = get_best_val(best_test_acc, test_acc)

        if args.print:
            if is_mse:
                log(Epoch=epoch,
                    Loss=loss,
                    Train=train_acc,
                    Val=val_acc,
                    Test=test_acc,)
            else:
                log(Epoch=epoch,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    test_loss=test_loss,
                    train_acc=train_acc,
                    test_acc=test_acc,
                    KL_Loss=kl_loss)

        if args.use_wandb:
            metrics = {
                'Loss/train': train_loss,
                'Loss/val': val_loss,
                'Loss/test': test_loss,
            }
            if not is_mse:
                metrics['Loss/kl_loss'] = kl_loss
                metrics['Loss/test_acc'] = test_acc
                metrics['Loss/train_acc'] = train_acc
            metrics['epoch'] = epoch
            wandb.log(metrics)
            log_gradients(model)
            log_statistics(model)

        if args.log_model and epoch % 100 == 0:
            torch.save(
                model,
                os.path.join(model_folder_path, f'checkpoint_{epoch:04}.pth'))

    if args.log_model:
        torch.save(model, args.workdir.checkpoint_path)
    if args.use_wandb:
        # Finish the run explicitly so a later run in the same process
        # (e.g. a Hydra multirun job) starts from a clean state.
        wandb.finish()
    return model

if __name__ == "__main__":
    # Hydra parses the command line and supplies the config to main().
    model = main()
