import argparse
import math
from typing import Dict
import logging
import torch
from torch import optim
import numpy as np
import os
import random

from tqdm import tqdm
from datasets import *
from optimizers import *
from models import *
from regularizers import N3, Lambda3
from ne import NE
from utils import *
from tasks import *


def get_model(sizes):
    """Build or load the model selected by ``args.model``.

    Args:
        sizes: Dataset shape as returned by ``TemporalDataset.get_shape()``.
            It is forwarded unchanged to the model constructor.

    Returns:
        ``None`` when ``args.model == 'search'``, since no learned model is used.
        Otherwise, the model loaded from ``<args.load_name>/model.pth`` when
        ``args.load_name`` is set, or else a freshly constructed model moved to
        the GPU with ``.cuda()``.

    Raises:
        KeyError: If ``args.model`` is not one of the known model names.
    """
    if args.model == 'search':
        return None
    if args.load_name is not None:
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
        # which rejects whole pickled modules saved via torch.save(model).
        # Pass weights_only=False (trusted checkpoints only) when upgrading.
        model = torch.load(os.path.join(args.load_name, 'model.pth'))
        logger.info(f'Loaded from {args.load_name}')
        return model

    # Factories are evaluated lazily so only the selected model is built.
    # Constructing every entry up front wastes memory on unused models and
    # fails if any of them needs an argument the current run did not set.
    factories = {
        'ComplEx': lambda: ComplEx(sizes, args.rank),
        'TComplEx': lambda: TComplEx(sizes, args.rank, no_time_emb=args.no_time_emb),
        'TNTComplEx': lambda: TNTComplEx(sizes, args.rank, no_time_emb=args.no_time_emb),
        'DE_SimplE': lambda: DE_SimplE(sizes, args.rank, args.se_prop, args.dropout,
                                       args.neg_ratio, no_time_emb=args.no_time_emb),
        'DE_DistMult': lambda: DE_DistMult(sizes, args.rank, args.se_prop, args.dropout,
                                           args.neg_ratio, no_time_emb=args.no_time_emb),
        'DE_TransE': lambda: DE_TransE(sizes, args.rank, args.se_prop, args.dropout,
                                       args.neg_ratio, no_time_emb=args.no_time_emb),
        'TeRo': lambda: TeRo(sizes, args.rank, args.gamma, args.n_day, args.neg_ratio),
        'ATISE': lambda: ATISE(sizes, args.rank, args.gamma, args.n_day,
                               args.neg_ratio, args.cmin),
        'NE': lambda: NE(sizes, args.rank, no_time_emb=args.no_time_emb,
                         lbl_cnt=args.cluster_num,
                         score_mode=args.score_mode, curve_mode=args.curve_mode,
                         omg_max=args.omg_max,
                         cluster_trainable=args.cluster_trainable),
    }
    model = factories[args.model]()
    return model.cuda()


def _make_epoch_optimizer(model, dataset, has_intervals, opt, emb_reg, time_reg, time_loss):
    """Return the per-epoch training wrapper that matches the dataset and ``args.model``."""
    if has_intervals:
        return IKBCOptimizer(
            model, emb_reg, time_reg, opt, dataset,
            batch_size=args.batch_size
        )
    if args.model == 'NE':
        return MyOptimizer(
            model, emb_reg, opt,
            time_loss=time_loss,
            batch_size=args.batch_size
        )
    if args.model in ('DE_SimplE', 'DE_DistMult', 'DE_TransE'):
        return DEOptimizer(model, opt, batch_size=args.batch_size)
    if args.model in ('TeRo', 'ATISE'):
        return TEOptimizer(args.model, model, opt, batch_size=args.batch_size)
    return TKBCOptimizer(
        model, emb_reg, time_reg, opt,
        batch_size=args.batch_size
    )


def train_model(model, dataset, train_data, train_time_range=None):
    """Train ``model`` on ``train_data`` for ``args.max_epochs`` epochs.

    Args:
        model: The model to optimise. Its parameters are passed to Adagrad.
        dataset: The ``TemporalDataset``. It provides ``get_shape()`` and
            ``has_intervals()``, the latter choosing the interval optimizer.
        train_data: 2-D integer array of facts. Column 3 is used as the
            timestamp (presumably rows are (head, relation, tail, time); TODO
            confirm). The caller's array is NOT modified; a shuffled copy is
            trained on.
        train_time_range: ``(start, end)`` timestamps (end exclusive) passed to
            ``LocalBlurLoss`` and to the NE optimizer. Defaults to
            ``(min_t, max_t + 1)`` taken from ``train_data``.

    Returns:
        ``None``. Returns immediately when ``train_data`` is ``None`` or empty.
    """
    sizes = dataset.get_shape()
    logger.info(f"sizes:{sizes}")
    if train_data is None or train_data.shape[0] == 0:
        return
    if train_time_range is None:
        train_time_range = (train_data[:, 3].min(), train_data[:, 3].max() + 1)
        logger.info(f"time range:{train_time_range}")

    opt = optim.Adagrad(model.parameters(), lr=args.learning_rate)
    # Decay the LR by 0.9 every `step_size` epochs. log(0.1)/log(0.9) ~= 21.85
    # decays would reach 10% of the initial LR, so over max_epochs the LR ends
    # at roughly (never below) ~10% of its start. max(1, ...) keeps step_size
    # valid when max_epochs == 0. `verbose` is omitted because it is deprecated.
    step_size = max(1, math.ceil(args.max_epochs / (math.log(0.1) / math.log(0.9))))
    scheduler = optim.lr_scheduler.StepLR(opt, step_size, 0.9)
    emb_reg = N3(args.emb_reg)
    time_reg = Lambda3(args.time_reg)
    # sizes[3] is presumably the number of timestamps — TODO confirm.
    time_loss = LocalBlurLoss(sizes[3], train_time_range, args.blur_range)

    # Shuffle a copy so the caller's array (e.g. taq/tam.train_data) keeps its order.
    shuffled = train_data.copy()
    np.random.shuffle(shuffled)
    examples = torch.from_numpy(shuffled)

    has_intervals = dataset.has_intervals()
    for epoch in range(args.max_epochs):
        logger.info(f'epoch: {epoch}')
        model.train()
        # The wrapper is rebuilt every epoch, as before, in case it holds per-epoch state.
        optimizer = _make_epoch_optimizer(
            model, dataset, has_intervals, opt, emb_reg, time_reg, time_loss
        )
        if not has_intervals and args.model == 'NE':
            # The NE optimizer also needs the time window being trained.
            optimizer.epoch(examples, train_time_range)
        else:
            optimizer.epoch(examples)
        scheduler.step()

        if args.valid_freq > 0 and (epoch + 1) % args.valid_freq == 0:
            eval_model(model, dataset)


def train_online(model, dataset, train_data):
    """Train ``model`` incrementally over consecutive time windows.

    The timestamps ``[0, max_t]`` (column 3 of ``train_data``) are split into
    windows of ``args.online_interval`` steps. For each window, ``train_model``
    is called on the facts whose timestamp falls inside it. The window bounds,
    with the last window clamped to ``max_t + 1``, are passed as the time range.

    Args:
        model: The model to train. It is updated in place window by window.
        dataset: The ``TemporalDataset``, forwarded to ``train_model``.
        train_data: 2-D integer array of facts, with column 3 as the timestamp.
            Does nothing when it is ``None`` or empty.

    Raises:
        ValueError: If ``args.online_interval`` is not positive.
    """
    if train_data is None or train_data.shape[0] == 0:
        return
    online_interval = args.online_interval
    if online_interval <= 0:
        raise ValueError(f'online_interval must be positive, got {online_interval}')

    timestamps = train_data[:, 3]
    time_max = timestamps.max() + 1
    for t in tqdm(range(0, time_max, online_interval)):
        l, r = t, min(t + online_interval, time_max)
        logger.info(f'Interval: [{l}, {r})')
        # All timestamps are < time_max, so clamping r selects the same rows.
        in_window = (l <= timestamps) & (timestamps < r)
        train_model(model, dataset, train_data[in_window], train_time_range=(l, r))


def evaluate_taq(model, dataset, online=False):
    """Run the temporal association query (TAQ) task.

    In ``train`` run mode (and when a learned model is in use), the model is
    first trained on ``taq.train_data``, either in one pass or window by window
    when ``online`` is set. It is then saved to ``<save_dir>/model.pth``. The
    model is always evaluated afterwards, and accuracy and CPU/GPU speedups
    are logged.

    Args:
        model: The model to train and evaluate. It is ``None`` for the
            ``'search'`` baseline.
        dataset: The ``TemporalDataset``.
        online: If true, train incrementally with ``train_online``.
    """
    taq = TemporalAssociationQuery(dataset, query_count=args.query_count, pair_freq=args.query_pair_freq)

    # Like evaluate_tam, skip training and saving for the model-less search baseline.
    if args.run_mode == 'train' and args.model != 'search':
        train_data = taq.train_data
        if online:
            train_online(model, dataset, train_data)
        else:
            train_model(model, dataset, train_data)
        torch.save(model, os.path.join(save_dir, 'model.pth'))

    accuracy, speedup_cpu, speedup_gpu = taq.evaluate(model)
    logger.info(f"Acc: {accuracy}, Speedup on CPU: {speedup_cpu}, Speedup on GPU: {speedup_gpu}")


def evaluate_tam(model, dataset, online=False):
    """Run the temporal association mining (TAM) task.

    In ``train`` run mode with a learned model, the model is trained on
    ``tam.train_data`` and saved to ``<save_dir>/model.pth``. In online mode,
    evaluation is restricted to the facts whose timestamp lies in
    ``[args.online_iteration, args.online_iteration + args.online_interval)``.
    Triple frequencies are then counted on the evaluation data. Finally, either
    the search baseline runs (``args.model == 'search'``) or the model is
    evaluated.

    Args:
        model: The model to train and evaluate, or ``None`` for ``'search'``.
        dataset: The ``TemporalDataset``. Its shape is used for
            ``TemporalAssociationMining`` and ``count_triples``.
        online: If true, train incrementally and evaluate on one window only.
    """
    # Computed locally instead of reading the module-level `sizes`, which is
    # only defined when this file is run as a script.
    sizes = dataset.get_shape()
    tam = TemporalAssociationMining(sizes, dataset)
    train_data = tam.train_data

    if args.run_mode == 'train' and args.model != 'search':
        if online:
            train_online(model, dataset, train_data)
        else:
            train_model(model, dataset, train_data)
        torch.save(model, os.path.join(save_dir, 'model.pth'))

    if online:
        window_start = args.online_iteration
        timestamps = train_data[:, 3]
        in_window = (timestamps >= window_start) & (timestamps < window_start + args.online_interval)
        train_data = train_data[in_window]

    tri_to_freq = count_triples(sizes, train_data)['tri_to_freq']

    if args.model == 'search':
        tam.run_search(train_data)
    else:
        tam.evaluate_model(model, tri_to_freq, train_data)


if __name__ == '__main__':
    # Task name -> (evaluation entry point, online flag).
    task_table = {
        'taq': (evaluate_taq, False),
        'taq_online': (evaluate_taq, True),
        'tam': (evaluate_tam, False),
        'tam_online': (evaluate_tam, True),
    }
    # Validate before the (potentially expensive) dataset/model loading.
    # Previously an unknown task silently did nothing.
    if args.task not in task_table:
        raise ValueError(f"Unknown task: {args.task!r}; expected one of {sorted(task_table)}")

    # Seed all RNGs used here: torch, `random`, and NumPy. NumPy's global RNG
    # drives np.random.shuffle of the training data in train_model and was
    # previously left unseeded, which made runs non-reproducible.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    dataset = TemporalDataset(args.dataset)
    sizes = dataset.get_shape()
    model = get_model(sizes)

    run_task, online = task_table[args.task]
    run_task(model, dataset, online=online)