from __future__ import division
from __future__ import print_function

import os
import json
import time
import wandb
import pickle
import logging
import datetime
import importlib
from pathlib import Path
from config import parser

import torch
import numpy as np

from datasets.utils import preprocess_data
from train_utils import get_dir_name, format_metrics


def _setup_save_dir(args):
    """Create the output directory for this run and route logging into it.

    If ``args.save_dir`` is empty, the directory is derived from
    ``$LOG_DIR/<task>/<year>_<month>_<day>`` via ``get_dir_name``; otherwise
    ``args.save_dir`` is used verbatim. The directory is created if missing.
    The root logger then writes to both ``<save_dir>/log.txt`` and the console.

    Returns:
        The output directory path.
    """
    if args.save_dir:
        save_dir = args.save_dir
    else:
        dt = datetime.datetime.now()
        date = f"{dt.year}_{dt.month}_{dt.day}"
        models_dir = os.path.join(os.environ['LOG_DIR'], args.task, date)
        save_dir = get_dir_name(models_dir)
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    # basicConfig only takes effect if the root logger has no handlers yet.
    logging.basicConfig(
        level=logging.INFO,
        handlers=[
            logging.FileHandler(os.path.join(save_dir, 'log.txt')),
            logging.StreamHandler()
        ]
    )
    return save_dir


def _build_optimizer(args, model):
    """Instantiate the ``torch.optim`` class named by ``args.optimizer``.

    For the FPST model, parameters whose name contains ``'c.'`` go into a
    separate group. That group uses ``args.curvature_lr`` and no weight decay;
    the names suggest learnable curvatures (NOTE(review): confirm against the
    FPST encoder). All other parameters use ``args.lr`` / ``args.weight_decay``.
    """
    optimizer_cls = getattr(torch.optim, args.optimizer)
    if args.model == 'FPST':
        curvatures = [param for name, param in model.named_parameters() if 'c.' in name]
        non_curvatures = [param for name, param in model.named_parameters() if 'c.' not in name]
        params = [{"params": curvatures, "lr": args.curvature_lr, 'weight_decay': 0},
                  {"params": non_curvatures}]
        return optimizer_cls(params, lr=args.lr, weight_decay=args.weight_decay)
    return optimizer_cls(params=model.parameters(), lr=args.lr, weight_decay=args.weight_decay)


def _save_artifacts(args, model, save_dir, best_emb):
    """Write the run's outputs into ``save_dir``.

    Files written: ``embeddings.npy`` (from ``best_emb``),
    ``<dataset>_att_adj.p`` (only if ``model.encoder`` has an ``att_adj``
    attribute), ``config.json`` (``vars(args)``) and ``model.pth`` (the
    model's ``state_dict``).
    """
    np.save(os.path.join(save_dir, 'embeddings.npy'), best_emb.cpu().detach().numpy())
    if hasattr(model.encoder, 'att_adj'):
        filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
        with open(filename, 'wb') as f:
            pickle.dump(model.encoder.att_adj.cpu().to_dense(), f)
        print('Dumped attention adj: ' + filename)

    with open(os.path.join(save_dir, 'config.json'), 'w') as f:
        json.dump(vars(args), f)
    torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
    logging.info(f"Saved model in {save_dir}")


def train(args):
    """Train a task model, track the best validation result, and optionally save it.

    Workflow:
    1. Seed the RNGs and choose the device.
    2. Import ``datasets.<dataset>``, ``encoders.<model>`` and ``tasks.<task>``.
    3. Load and preprocess the data.
    4. Build the encoder, task model, optimizer and StepLR scheduler.
    5. Train for up to ``args.epochs`` epochs, with periodic validation and
       patience-based early stopping.
    6. Log metrics to wandb and, when ``args.save`` is set, write the
       artifacts to disk.

    Args:
        args: parsed command-line namespace. It is mutated in place. Always
            set: ``device`` and ``patience``, ``n_nodes`` and ``feat_dim``.
            Conditionally set: ``n_classes`` (task 'nc'), ``eval_freq``
            (task 'md') and ``lr_reduce_freq`` (when it is falsy).

    Environment:
        DATAPATH: data root passed to the dataset loader.
        LOG_DIR: root for auto-named output dirs (only when saving without
            an explicit ``args.save_dir``).
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    if int(args.cuda) >= 0:
        torch.cuda.manual_seed(args.seed)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)
    save_dir = _setup_save_dir(args) if args.save else None

    logging.info(f'Using: {args.device}')
    logging.info("Using seed {}.".format(args.seed))

    # Load data. The module names come straight from the CLI arguments.
    data_module = importlib.import_module(f'datasets.{args.dataset}')
    encoder_module = importlib.import_module(f'encoders.{args.model}')
    task_module = importlib.import_module(f'tasks.{args.task}')

    data = getattr(data_module, f'load_{args.task}_data')(args, os.environ['DATAPATH'])
    postprocess_fn = getattr(encoder_module, 'postprocess_fn')
    data = preprocess_data(args, data, postprocess_fn)

    # TokenGT/FPST features are unpacked as 3-D (middle axis ignored);
    # other models' features are unpacked as 2-D.
    if args.model in ['TokenGT', 'FPST']:
        args.n_nodes, _, args.feat_dim = data['features'].shape
    else:
        args.n_nodes, args.feat_dim = data['features'].shape

    if args.task == 'nc':
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    elif args.task == 'md':
        # eval_freq > epochs means the periodic validation pass never fires.
        args.eval_freq = args.epochs + 1

    if not args.lr_reduce_freq:
        args.lr_reduce_freq = args.epochs

    # Model and optimizer
    encoder = getattr(encoder_module, 'Encoder')(args)
    model = getattr(task_module, 'TaskModel')(args, encoder)
    logging.info(str(model))
    optimizer = _build_optimizer(args, model)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=int(args.lr_reduce_freq),
        gamma=float(args.gamma)
    )

    # numel() stays an int even for 0-d parameters; np.prod(()) would give 1.0.
    tot_params = sum(p.numel() for p in model.parameters())
    logging.info(f"Total number of parameters: {tot_params}")
    if args.cuda is not None and int(args.cuda) >= 0:
        model = model.to(args.device)
        # Reassigning existing keys during iteration is safe (size unchanged).
        for key, value in data.items():
            if torch.is_tensor(value):
                data[key] = value.to(args.device)

    wandb.init(project=args.project)
    wandb.run.name = args.exp_name
    wandb.config.update(args)

    # Train model
    t_total = time.time()
    counter = 0
    best_val_metrics = model.init_metric_dict()
    best_test_metrics = model.init_metric_dict()
    best_emb = None
    epoch = -1  # keeps `epoch` bound for the final wandb.log even if args.epochs == 0
    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()

        embeddings = encoder(data)
        train_metrics = model.compute_metrics(embeddings, data, 'train')
        train_metrics['loss'].backward()
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            # NOTE(review): this clips each parameter tensor to its own norm, not
            # the global norm over all parameters. The usual form is
            # clip_grad_norm_(model.parameters(), max_norm). Kept unchanged
            # because switching would change training dynamics.
            for param in model.parameters():
                torch.nn.utils.clip_grad_norm_(param, max_norm)

        optimizer.step()
        lr_scheduler.step()
        if (epoch + 1) % args.log_freq == 0:
            logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1),
                                   'lr: {}'.format(lr_scheduler.get_last_lr()[0]),
                                   format_metrics(train_metrics, 'train'),
                                   'time: {:.4f}s'.format(time.time() - t)
                                   ]))

            wandb.log({
                'epoch': epoch,
                **{'train_'+metric_name: metric_value for metric_name, metric_value in train_metrics.items()},
            })

        # Task 'md' has no validation pass. Test metrics are recomputed whenever
        # has_improved() reports the train metrics beat the stored test metrics.
        if args.task == "md" and model.has_improved(best_test_metrics, train_metrics):
            best_test_metrics = model.compute_metrics(embeddings, data, 'test')

        if (epoch + 1) % args.eval_freq == 0:
            model.eval()
            with torch.no_grad():
                embeddings = encoder(data)

                val_metrics = model.compute_metrics(embeddings, data, 'val')
                if (epoch + 1) % args.log_freq == 0:
                    logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1), format_metrics(val_metrics, 'val')]))

                wandb.log({
                    'epoch': epoch,
                    **{'val_'+metric_name: metric_value for metric_name, metric_value in val_metrics.items()},
                })

                if model.has_improved(best_val_metrics, val_metrics):
                    best_test_metrics = model.compute_metrics(embeddings, data, 'test')
                    best_emb = embeddings.cpu()
                    if args.save:
                        np.save(os.path.join(save_dir, 'embeddings.npy'), best_emb.detach().numpy())
                    best_val_metrics = val_metrics
                    counter = 0
                else:
                    counter += 1
                    # `>=` rather than `==`: if patience runs out before
                    # min_epochs, counter keeps growing and must still trigger
                    # once min_epochs has passed.
                    if counter >= args.patience and epoch > args.min_epochs:
                        logging.info("Early stopping")
                        break

    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    if not best_test_metrics:
        model.eval()
        with torch.no_grad():
            best_emb = encoder(data)
            best_test_metrics = model.compute_metrics(best_emb, data, 'test')

    logging.info(" ".join(["Val set results:", format_metrics(best_val_metrics, 'val')]))
    logging.info(" ".join(["Test set results:", format_metrics(best_test_metrics, 'test')]))

    wandb.log({
        'epoch': epoch,
        **{'test_'+metric_name: metric_value for metric_name, metric_value in best_test_metrics.items()},
        **{'best_val_'+metric_name: metric_value for metric_name, metric_value in best_val_metrics.items()},
    })

    if args.save:
        if best_emb is None:
            # No embeddings were recorded. For example, task 'md' never runs
            # validation (eval_freq > epochs). Use the final model's
            # embeddings instead of crashing on None.cpu().
            model.eval()
            with torch.no_grad():
                best_emb = encoder(data)
        _save_artifacts(args, model, save_dir, best_emb)

if __name__ == '__main__':
    args = parser.parse_args()

    # For the 'md' task with saving enabled, write outputs to a fixed
    # dataset/model/dim directory and make head_dim equal to dim.
    if args.task == 'md' and args.save == 1:
        args.head_dim = args.dim
        args.save_dir = f"logs/md/{args.dataset}/{args.model}/{args.dim}"

    train(args)
