from datetime import datetime
import os
import random
import shutil

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import torch


def init(args, deterministic=True):
    """Seed the global RNGs and select the compute device.

    Args:
        args: namespace-like object. Reads ``args.seed`` and, only when CUDA
            is available, ``args.gpu`` (a device index or ``None``). Writes
            ``args.device_type`` (``'cuda'`` or ``'cpu'``) and
            ``args.device`` (a ``torch.device``).
        deterministic (bool): value assigned to
            ``torch.backends.cudnn.deterministic``. When True, cuDNN
            benchmark mode is also switched off.

    Side effects:
        Seeds Python's ``random``, NumPy's global RNG and torch, and sets
        torch's default device for subsequently created tensors.
    """
    # Seed every global RNG downstream code may draw from, including the
    # stdlib ``random`` module, so runs are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    torch.backends.cudnn.deterministic = deterministic
    if deterministic:
        # Benchmark mode picks convolution algorithms by timing them, which
        # can differ between runs and defeat reproducibility.
        torch.backends.cudnn.benchmark = False

    if torch.cuda.is_available():
        args.device_type = 'cuda'
        device = 'cuda' if args.gpu is None else f'cuda:{args.gpu}'
    else:
        args.device_type = 'cpu'
        device = 'cpu'

    args.device = torch.device(device)
    torch.set_default_device(device)


def init_log(args, fpath='log.txt'):
    """Print the start-of-run banner and write it to a fresh log file.

    The banner is a timestamped header followed by the run settings read
    from ``args`` and a closing rule. It is emitted with
    ``log(..., 'ini', fpath)``, which opens ``fpath`` in write mode and so
    replaces any previous content.

    Args:
        args: namespace-like object providing ``name``, ``mode``, ``seed``,
            ``device`` and ``epochs``; additionally ``d`` when
            ``mode == 'mt'``, and ``d`` and ``r`` when ``mode`` is ``'tt'``
            or ``'lr2'``.
        fpath (str): path of the log file.
    """
    dt = datetime.now().strftime('%Y-%m-%d %H-%M-%S')
    content = 'Start computations'
    # "[<dt>] >> " is 25 characters wide; the rules are sized to line up
    # with it and with ``content``.
    header = f'[{dt}] >> {content}'
    header += '\n' + '=' * 24 + ' ' + '-' * len(content) + '\n'

    lines = [
        f'Name                   : {args.name}',
        f'Mode                   : {args.mode}',
        f'Seed                   : {args.seed}',
        f'Device                 : {args.device}',
        f'Epochs                 : {args.epochs}',
    ]
    if args.mode == 'mt':
        lines.append(f'Number of heads        : {args.d}')
    if args.mode in ('tt', 'lr2'):
        lines.append(f'TT-dimension           : {args.d}')
        lines.append(f'TT-rank                : {args.r}')

    # No trailing newline on the settings: the closing rule below supplies
    # the separator, so every mode ends without a stray blank line.
    text_out = header + '\n'.join(lines)
    text_out += '\n' + '=' * (25 + len(content)) + '\n'

    log(text_out, 'ini', fpath)


def init_path(name, root='result', rewrite=False):
    """Create an empty results folder ``<root>/<name>``.

    ``root`` is created if missing. If ``<root>/<name>`` already exists, it
    is deleted with all its contents and then recreated. This happens
    unconditionally when ``rewrite`` is True; otherwise the user is asked to
    confirm at an interactive prompt.

    Args:
        name (str): sub-folder name inside ``root``.
        root (str): parent folder for results.
        rewrite (bool): remove an existing folder without asking.

    Raises:
        ValueError: if the folder exists and the user declines removal.
    """
    os.makedirs(root, exist_ok=True)
    # Plain string join (not os.path.join) so an absolute ``name`` can never
    # redirect the rmtree below outside ``root``.
    fold = f'{root}/{name}'
    if os.path.isdir(fold):
        if not rewrite:
            msg = f'Path "{fold}" already exists. Remove? [y/n] '
            # Case-insensitive and whitespace-tolerant: "y", "Y", "yes"
            # confirm; any other answer keeps the folder.
            answer = input(msg).strip().lower()
            if answer not in ('y', 'yes'):
                raise ValueError('Folder with results already exists')
        shutil.rmtree(fold)
    os.makedirs(fold)


def log(text, kind='', fpath='log.txt'):
    """Echo ``text`` to stdout and record it in the log file ``fpath``.

    ``kind`` selects a marker placed before the text:
    ``'prc'`` -> ``'... '``, ``'res'`` -> ``'+++ '``, ``'wrn'`` -> ``'WRN '``,
    ``'err'`` -> ``'!!! '``; any other value adds no marker.
    ``kind == 'ini'`` opens the file in write mode (discarding previous
    content); every other kind appends. A newline is added after the text.
    """
    markers = {
        'prc': '... ',
        'res': '+++ ',
        'wrn': 'WRN ',
        'err': '!!! ',
    }
    line = markers.get(kind, '') + text
    mode = 'w' if kind == 'ini' else 'a+'
    with open(fpath, mode, encoding='utf-8') as handle:
        handle.write(line + '\n')
    print(line)


def plot_loss(losses, fpath=None, is_batch=False):
    """Plot training-loss values on a logarithmic y-axis.

    Point ``i`` (1-based) is ``losses[i - 1]``. The x-axis is labelled
    "Batch number" when ``is_batch`` is True and "Epoch" otherwise, and it
    shows integer ticks only. If ``fpath`` is given, the figure is saved
    there and closed; otherwise it is displayed with ``plt.show()``.
    """
    fig = plt.figure(figsize=(6, 6))
    ax = fig.gca()

    xlabel = 'Batch number' if is_batch else 'Epoch'
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Loss')

    steps = np.arange(1, len(losses) + 1)
    ax.plot(steps, losses, marker='o', markersize=8)
    ax.semilogy()
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    if not fpath:
        plt.show()
        return

    plt.savefig(fpath, bbox_inches='tight')
    plt.close(fig)