from ast import arg
import logging
import random
import os
from os.path import join
import time
import yaml
import json
import shutil
from copy import deepcopy
import argparse
from collections import defaultdict
import numpy as np
import torch
import math
from torch.optim import Optimizer
import torch.distributed as dist
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

import torch.nn as nn
from easydict import EasyDict as edict


def setup_seed(seed):
    """Seed every random number generator used by this module.

    Seeds Python's ``random``, NumPy, and PyTorch (the default generator and
    all CUDA devices), then sets the cuDNN ``deterministic`` flag.

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    
def setup_device(args):
    """Record the compute device and GPU count on ``args``.

    Args:
        args: Mutable namespace-like object. It gains ``device`` ('cuda' when
            ``torch.cuda.is_available()`` is true, otherwise 'cpu') and
            ``n_gpu`` (the value of ``torch.cuda.device_count()``).
    """
    if torch.cuda.is_available():
        args.device = 'cuda'
    else:
        args.device = 'cpu'
    args.n_gpu = torch.cuda.device_count()
    print("---------", args.device)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of mappings
            that each hold an ``'lr'`` entry.

    Returns:
        The ``'lr'`` value of the first group, or ``None`` when there are no
        parameter groups at all.
    """
    groups = iter(optimizer.param_groups)
    try:
        first_group = next(groups)
    except StopIteration:
        return None
    return first_group['lr']

class AverageMeter(object):
    """Keeps the most recent value of a metric and its running weighted mean.

    Attributes are ``val`` (last value seen), ``sum`` (weighted total),
    ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set ``val``, ``avg``, ``sum`` and ``count`` back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the mean.

        Args:
            val: Newly observed value.
            n: Weight of the observation (for example a batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division so a zero total weight yields an average of 0.
        if self.count != 0:
            self.avg = self.sum / self.count
        else:
            self.avg = 0

        
class Lookahead(Optimizer):
    """Wrap a base optimizer with slow/fast ("lookahead") weight averaging.

    Every ``k`` calls to :meth:`step`, each parameter's slow copy moves a
    fraction ``alpha`` of the way towards the current (fast) weights, and the
    fast weights are then overwritten with that slow copy.

    NOTE(review): ``Optimizer.__init__`` is never called here, so attributes
    it would normally create are absent; confirm that the installed torch
    version's ``Optimizer.load_state_dict`` (used in :meth:`load_state_dict`)
    does not depend on them.

    Args:
        base_optimizer: Optimizer that performs the actual (fast) updates.
        alpha: Slow-weight interpolation factor, must lie in ``[0, 1]``.
        k: Number of fast steps between slow-weight syncs, must be ``>= 1``.

    Raises:
        ValueError: If ``alpha`` or ``k`` is out of range.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        # Share the very same list so both optimizers see identical groups.
        self.param_groups = self.base_optimizer.param_groups
        # This is the base optimizer's own dict; the update below extends it.
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group):
        """Pull slow weights towards fast weights, then reset fast to slow.

        Parameters without a gradient are skipped. The slow buffer is created
        on first use as a copy of the current fast weights, so the first sync
        of a parameter leaves its value unchanged.

        Args:
            group: One entry of ``self.param_groups``.
        """
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if 'slow_buffer' not in param_state:
                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
                param_state['slow_buffer'].copy_(fast_p.data)
            slow = param_state['slow_buffer']
            # slow += alpha * (fast - slow). The scalar is passed via the
            # ``alpha`` keyword; the positional-scalar overload is deprecated.
            slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        """Run :meth:`update_slow` on every parameter group immediately."""
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure=None):
        """Take one base-optimizer step and sync slow weights every ``k`` steps.

        Args:
            closure: Optional closure forwarded to the base optimizer.

        Returns:
            Whatever the base optimizer's ``step`` returns.
        """
        #assert id(self.param_groups) == id(self.base_optimizer.param_groups)
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        """Return the base optimizer state plus the slow-weight state.

        Returns:
            Dict with ``'state'`` and ``'param_groups'`` taken from the base
            optimizer's ``state_dict()`` and ``'slow_state'`` holding this
            wrapper's state, with tensor keys replaced by ``id(tensor)``.
        """
        fast_state_dict = self.base_optimizer.state_dict()
        # NOTE(review): id() keys are only meaningful inside this process;
        # confirm they round-trip through load_state_dict as intended.
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict['state']
        param_groups = fast_state_dict['param_groups']
        return {
            'state': fast_state,
            'slow_state': slow_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        """Restore base-optimizer state and, when present, the slow state.

        The caller's ``state_dict`` is left unmodified.

        Args:
            state_dict: Dict with ``'state'`` and ``'param_groups'`` and
                optionally ``'slow_state'`` (absent when the checkpoint was
                written by an optimizer without this wrapper).
        """
        fast_state_dict = {
            'state': state_dict['state'],
            'param_groups': state_dict['param_groups'],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # We want to restore the slow state, but share param_groups reference
        # with base_optimizer. This is a bit redundant but least code
        slow_state_new = 'slow_state' not in state_dict
        if slow_state_new:
            print('Loading state_dict from optimizer without Lookahead applied.')
            # Use a local instead of writing into the caller's dict.
            slow_state = defaultdict(dict)
        else:
            slow_state = state_dict['slow_state']
        slow_state_dict = {
            'state': slow_state,
            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
        if slow_state_new:
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)



class EarlyStopping(object):
    """Tracks the best score seen and signals when to stop training.

    Args:
        patience: Number of consecutive non-improving calls after which the
            stop flag is raised.
    """

    def __init__(self, patience):
        super().__init__()
        self.patience = patience
        self.counter = 0
        self.best_score = None

    def state_dict(self):
        """Return ``best_score`` and ``counter`` as a plain dict."""
        return dict(best_score=self.best_score, counter=self.counter)

    def load_state_dict(self, state_dict):
        """Restore ``best_score`` and ``counter`` from ``state_dict``."""
        self.best_score = state_dict['best_score']
        self.counter = state_dict['counter']

    def __call__(self, score):
        """Register a new score.

        Args:
            score: Latest validation score; larger is treated as better.

        Returns:
            Tuple ``(is_save, is_terminate)``. ``is_save`` is true for the
            first score and for any score strictly above the best so far;
            ``is_terminate`` is true once the count of consecutive
            non-improving scores reaches ``patience``.
        """
        has_best = self.best_score is not None
        if has_best and self.best_score >= score:
            # No improvement: count it and check the patience budget.
            self.counter += 1
            return False, bool(self.counter >= self.patience)

        if has_best:
            # A genuine improvement resets the counter; the very first score
            # leaves the counter untouched.
            self.counter = 0
        self.best_score = score
        return True, False

class ModelEma(torch.nn.Module):
    """Holds a deep copy of a model and updates it as a moving average.

    Args:
        model: Model whose ``state_dict`` entries are tracked.
        decay: Weight given to the existing averaged value on each update.
        device: Optional device for the copy; when set, the copy is moved
            there and incoming tensors are moved there before blending.
    """

    def __init__(self, model, decay=0.9997, device=None):
        super().__init__()
        # Independent copy, kept in eval mode, that accumulates the average.
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """Overwrite each averaged tensor with ``update_fn(ema, current)``.

        Tensors are paired by iteration order of the two ``state_dict()``
        results.
        """
        with torch.no_grad():
            ema_tensors = self.module.state_dict().values()
            src_tensors = model.state_dict().values()
            for ema_t, src_t in zip(ema_tensors, src_tensors):
                if self.device is not None:
                    src_t = src_t.to(device=self.device)
                ema_t.copy_(update_fn(ema_t, src_t))

    def update(self, model):
        """Blend ``model`` into the copy: ``decay * ema + (1 - decay) * model``."""
        def _blend(e, m):
            return self.decay * e + (1. - self.decay) * m

        self._update(model, update_fn=_blend)

    def set(self, model):
        """Copy the values of ``model`` into the averaged copy verbatim."""
        def _take_model(e, m):
            return m

        self._update(model, update_fn=_take_model)

    def forward(self, x, inference=False, kfold=False):
        """Run the averaged copy, forwarding ``inference`` and ``kfold``."""
        return self.module(x, inference=inference, kfold=kfold)
    
def get_experiment_id(exp_home):
    """Return the next unused experiment number under ``exp_home``.

    Only entries named exactly ``exp<N>`` (``N`` a decimal number) are
    counted. Other entries, including ``exp<N>_resume`` directories and
    unrelated names that merely end in a digit, are ignored instead of
    being fed to ``int()`` where they could raise ``ValueError``.

    Args:
        exp_home: Directory that holds the experiment sub-directories.

    Returns:
        ``1`` when no ``exp<N>`` entry exists, otherwise the largest ``N``
        plus one.
    """
    prefix = 'exp'
    exp_ids = []
    for entry in os.listdir(exp_home):
        suffix = entry[len(prefix):]
        # isdecimal() is False for '' and for anything int() cannot parse.
        if entry.startswith(prefix) and suffix.isdecimal():
            exp_ids.append(int(suffix))
    return max(exp_ids, default=0) + 1

def check_makedir(dir):
    """Create ``dir`` (with any missing parents) unless it already exists.

    The path is always printed; a second line is printed when the directory
    actually gets created.

    Args:
        dir: Directory path to ensure.
    """
    print("dir:", dir)
    if os.path.exists(dir):
        return
    print("makedir:", dir)
    os.makedirs(dir)
        
def check_exists(filepath):
    """Raise if ``filepath`` does not exist on disk.

    Args:
        filepath: Path to check.

    Raises:
        FileNotFoundError: When ``os.path.exists(filepath)`` is false.
    """
    if os.path.exists(filepath):
        return
    raise FileNotFoundError('file {} not found!'.format(filepath))

def load_yaml(path):
    """Load a ``.yaml`` file and return the parsed object.

    Args:
        path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.load`` produces for the file's contents.
    """
    # The context manager closes the handle even when parsing raises; the
    # previous version left the file open.
    with open(path, 'r') as file:
        # NOTE(security): FullLoader is not meant for untrusted input; if a
        # config can ever come from an untrusted source, yaml.safe_load is
        # the safer choice.
        yaml_obj = yaml.load(file.read(), Loader=yaml.FullLoader)
    return yaml_obj

def prepare_env(args, argv):
    """Build the training config and set up the experiment directory.

    Steps, in order: detect the device, load the YAML config named by
    ``args.cfg``, optionally initialise distributed training, choose a new
    experiment directory, and (on the main process only) create the
    directories, append the command line to ``checkpoint.txt``, save the
    config and start logging.

    NOTE(review): with DDP enabled every rank calls ``get_experiment_id``
    while only rank 0 creates directories; confirm that ranks cannot end up
    with different experiment ids.

    Args:
        args: Parsed CLI arguments; ``cfg`` and ``version`` are read, and
            ``setup_device`` adds ``device`` and ``n_gpu``.
        argv: Command-line tokens recorded in ``checkpoint.txt``.

    Returns:
        The config with ``common.exp_dir``, ``common.log_path`` and
        ``common.ckpt_dir`` filled in.
    """
    setup_device(args)

    config = edict(load_yaml(args.cfg))
    # A truthy CLI version overrides the one stored in the config file.
    config.model.version = args.version or config.model.version

    if config.train.ddp.istrue:
        local_rank = int(os.environ["LOCAL_RANK"])
        config.train.ddp.local_rank = local_rank
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')

    # prepare checkpoint and log config
    dataset_tag = ",".join(sorted(config.train.dataset.which_use))
    exp_home = join(config.common.exp_home, '{}_{}'.format(config.model.name, dataset_tag))

    # Filesystem writes happen only without DDP or on rank 0.
    is_main_process = not config.train.ddp.istrue or dist.get_rank() == 0
    if is_main_process:
        check_makedir(exp_home)

    resuming = config.train.resume
    exp_id = get_experiment_id(exp_home)
    if resuming:
        exp_name = 'exp{}_{}'.format(exp_id, 'resume')
    else:
        exp_name = 'exp{}'.format(exp_id)

    exp_dir = join(exp_home, exp_name)
    config.common.exp_dir = exp_dir
    config.common.log_path = join(exp_dir, 'train.log')
    config.common.ckpt_dir = join(exp_dir, 'checkpoints')

    if is_main_process:
        check_makedir(config.common.exp_dir)
        check_makedir(config.common.ckpt_dir)

        # save experiment checkpoint
        command_line = ' '.join(['python'] + list(argv))
        with open(os.path.join(exp_home, 'checkpoint.txt'), 'a') as record:
            record.write('{}\t{}\n'.format(exp_name, command_line))

        # save config
        plain_config = json.loads(json.dumps(config))
        with open(join(config.common.exp_dir, 'config.yaml'), 'w', encoding='utf-8') as cfg_file:
            yaml.dump(plain_config, cfg_file)

        # prepare log
        prepare_log(join(exp_dir, 'train.log'), config)

    return config

def prepare_val_env(args, argv):
    """Build the evaluation config from a checkpoint path and set up outputs.

    The experiment directory is taken to be two levels above the checkpoint
    file, e.g. ``<exp_dir>/checkpoints/model_epoch_999.bin``. Its
    ``config.yaml`` is loaded, per-dataset output directories are created
    under ``<config.common.exp_dir>/test/<model file name>/``, the config is
    saved there, and a fresh log file is started.

    Args:
        args: Parsed CLI arguments; ``test_model``, ``data``, ``length`` and
            ``threads`` are read, and ``setup_device`` adds ``device`` and
            ``n_gpu``.
        argv: Unused; kept so the signature matches ``prepare_env``.

    Returns:
        The loaded config with ``test.data``, ``common.test_model`` and
        ``test.threads`` set.
    """
    setup_device(args)

    # e.g. '<exp_home>/exp1/checkpoints/model_epoch_999.bin'
    load_path = args.test_model
    model_name = os.path.basename(load_path)
    # Walk up two levels with os.path rather than chained str.replace, which
    # would strip *every* occurrence of the file or directory name.
    exp_dir = os.path.dirname(os.path.dirname(load_path))

    config_file = os.path.join(exp_dir, 'config.yaml')
    print(config_file)
    config = edict(load_yaml(config_file))

    config.test.data = args.data
    datasets = config.test.data.split(',')
    datasets_ = '_'.join(datasets)

    test_dir = os.path.join(config.common.exp_dir, 'test')
    model_dir = os.path.join(test_dir, model_name)
    check_makedir(test_dir)
    check_makedir(model_dir)
    config.common.test_model = load_path

    # A length other than -1 is appended to directory and log names.
    length_suffix = '' if args.length == -1 else '_' + str(args.length)
    for data in datasets:
        check_makedir(os.path.join(model_dir, data + length_suffix))

    # save config
    cfg_path = join(model_dir, 'config.yaml')
    config_ = json.loads(json.dumps(config))
    with open(cfg_path, 'w', encoding='utf-8') as f:
        yaml.dump(config_, f)

    log_path = join(model_dir, datasets_ + length_suffix + '.log')
    # Start from an empty log for this model/dataset combination.
    if os.path.exists(log_path):
        os.remove(log_path)

    prepare_val_log(log_path, config)

    # Set after the config was written, so it is not part of the saved file.
    config.test.threads = args.threads

    return config

def prepare_log(log_path, cfg, level=logging.INFO):
    """Attach console and file handlers to the root logger and log the config.

    Args:
        log_path: File that receives the log records (UTF-8).
        cfg: JSON-serialisable mapping; it is dumped and logged verbatim.
        level: Level applied to the root logger.
    """
    config_text = json.dumps(cfg, indent=1, separators=(', ', ': '), ensure_ascii=False)

    root_logger = logging.getLogger()
    root_logger.setLevel(level)
    new_handlers = (
        logging.StreamHandler(),
        logging.FileHandler(filename=log_path, encoding='utf-8'),
    )
    for handler in new_handlers:
        root_logger.addHandler(handler)

    started_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    root_logger.info('model training time: {}'.format(started_at))
    root_logger.info('model configuration: ')
    root_logger.info(config_text)

def prepare_val_log(log_path, cfg, level=logging.INFO):
    """Attach console and file handlers to the root logger for evaluation.

    Logs the current time, ``cfg.common.test_model`` and ``cfg.test.data``.

    Args:
        log_path: File that receives the log records (UTF-8).
        cfg: Config object exposing ``common.test_model`` and ``test.data``.
        level: Level applied to the root logger.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(level)
    new_handlers = (
        logging.StreamHandler(),
        logging.FileHandler(filename=log_path, encoding='utf-8'),
    )
    for handler in new_handlers:
        root_logger.addHandler(handler)

    started_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    root_logger.info('model verifying time: {}'.format(started_at))
    root_logger.info('load model file: {}'.format(cfg.common.test_model))
    root_logger.info('load datasets for verifying: {}'.format(cfg.test.data))
    
def clear_exp(exp_dir):
    """Delete an experiment directory and drop the last bookkeeping line.

    Logging is shut down first, ``exp_dir`` is removed recursively, and the
    last line of ``checkpoint.txt`` in the parent directory is discarded.

    Args:
        exp_dir: Experiment directory to remove.
    """
    logging.shutdown()
    shutil.rmtree(exp_dir)

    record_path = os.path.join(os.path.dirname(exp_dir), 'checkpoint.txt')
    with open(record_path, 'r') as reader:
        kept_lines = reader.readlines()
    # Discard the final record; a no-op when the file is empty.
    del kept_lines[-1:]
    with open(record_path, 'w') as writer:
        writer.writelines(kept_lines)

def evaluate(predictions, labels):
    """Score predictions against labels with four scikit-learn metrics.

    Each metric is called as ``metric(labels, predictions)`` with its
    default options.

    Args:
        predictions: Predicted labels.
        labels: Ground-truth labels.

    Returns:
        Dict with keys ``'accuracy'``, ``'p'`` (precision), ``'r'`` (recall)
        and ``'f1'``.
    """
    scorers = (
        ('accuracy', accuracy_score),
        ('p', precision_score),
        ('r', recall_score),
        ('f1', f1_score),
    )
    return {key: scorer(labels, predictions) for key, scorer in scorers}


def seconds_to_dhms(seconds):
    """Format a duration in seconds as ``[[[DD:]HH:]MM:]SS``.

    Leading fields that are zero are omitted; every field shown is
    zero-padded to two digits.

    Args:
        seconds: Duration in seconds. Floats (for example a difference of
            ``time.time()`` values) are truncated to whole seconds; before
            this change they raised ``ValueError`` in the ``d`` format code.

    Returns:
        The formatted string, e.g. ``"05"``, ``"01:05"``, ``"01:01:01"`` or
        ``"01:01:01:01"``.
    """
    total = int(seconds)
    minutes_total, secs = divmod(total, 60)
    hours_total, minutes = divmod(minutes_total, 60)
    days, hours = divmod(hours_total, 24)

    # Keep only the fields from the largest non-zero unit downwards.
    if days > 0:
        fields = (days, hours, minutes, secs)
    elif hours > 0:
        fields = (hours, minutes, secs)
    elif minutes > 0:
        fields = (minutes, secs)
    else:
        fields = (secs,)
    return ":".join("{:0>2d}".format(field) for field in fields)


def check_trainable(model, logger, print=True):
    """Collect the model's trainable parameters and optionally log their names.

    Args:
        model: Object exposing ``parameters()`` and ``named_parameters()``.
        logger: Object exposing ``info``; receives a header line and, when
            ``print`` is true, one line per trainable parameter name.
        print: Whether to log the individual parameter names.

    Returns:
        List of parameters whose ``requires_grad`` is true.

    Raises:
        AssertionError: If no parameter is trainable.
    """
    trainable_params = []
    for param in model.parameters():
        if param.requires_grad:
            trainable_params.append(param)

    logger.info('trainable params:')
    if print:
        trainable_named = filter(
            lambda item: item[1].requires_grad, model.named_parameters()
        )
        for name, _ in trainable_named:
            logger.info(name)

    assert len(trainable_params) > 0, 'no trainable parameters'

    return trainable_params

def is_valid_number(x):
    """Return True when ``x`` is neither NaN nor infinite and not above 1e4.

    Args:
        x: Number to check.

    Returns:
        ``False`` for NaN, for positive or negative infinity, and for values
        greater than ``1e4``; ``True`` otherwise.
    """
    return not math.isnan(x) and not math.isinf(x) and not x > 1e4