import os
import gc
import sys
from turtle import st
import git
import time
import math
import json
import socket
import random
import datetime
import pathlib
import subprocess
from collections import defaultdict, deque
import GPUtil
import numpy as np
import torch
import pickle
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from oauth2client.file import Storage
from oauth2client import client, tools


def init_wandb(args):
    """Set up experiment paths on ``args`` and start a Weights & Biases run.

    Attributes written to ``args``:
        output_path: ``<args.output_dir>/<args.exp>`` (directory is created).
        load_path:   ``<args.output_dir>/<args.exp_load>`` (directory is
                     created); only set when ``args.exp_load`` is non-empty.
        chkpt_path:  ``<output_path>/chkpt_final.pth``.

    After the run is initialised, ``val/step`` is registered as the step
    metric for every ``val/*`` metric, ``args`` is pushed to the wandb
    config, all arguments are printed sorted by name, and the run metadata
    is written through ``write_exp_json``.
    """
    import wandb

    run_name = args.exp
    args.output_path = get_path(args.output_dir, args.exp)
    if len(args.exp_load) > 0:
        args.load_path = get_path(args.output_dir, args.exp_load)
    args.chkpt_path = os.path.join(args.output_path, 'chkpt_final.pth')

    # wandb keeps its local files in a per-experiment directory.
    wandb_dir = get_path(args.wandb_dir, args.exp)
    wandb.init(
        project=args.proj_name,
        name=run_name,
        tags=args.tags,
        dir=wandb_dir,
        entity=args.wandb_entity,
    )
    wandb.define_metric("val/step")
    wandb.define_metric("val/*", step_metric="val/step")
    wandb.config.update(args, allow_val_change=True)

    # One "name: value" line per argument, ordered by argument name.
    arg_lines = [
        f"{key!s}: {value!s}" for key, value in sorted(vars(args).items())
    ]
    print("\n".join(arg_lines))
    write_exp_json(args)


def get_path(dir, exp):
    """Join ``dir`` and ``exp``, create that directory if needed, return it.

    Missing parent directories are created as well, and an already
    existing directory is not an error. The returned value is the joined
    path as a string.
    """
    joined = os.path.join(dir, exp)
    target = pathlib.Path(joined)
    target.mkdir(parents=True, exist_ok=True)
    return joined


def write_exp_json(args):
    """Write the run's metadata to ``<args.output_path>/dt.json``.

    The JSON object contains, in this order: the host IP (``ip``),
    ``st_time``, ``sha``, ``seed`` and ``exp`` copied from ``args``, and the
    complete argument namespace under ``args``.

    Raises:
        AttributeError: if ``args`` lacks one of the attributes read here.
        TypeError: if a value in ``vars(args)`` is not JSON serialisable.
            Serialisation happens before the file is opened, so a failure
            does not leave a truncated ``dt.json`` behind.
    """
    try:
        ip = socket.gethostbyname(socket.gethostname())
    except OSError:
        # socket.gaierror is an OSError subclass; a host name that cannot be
        # resolved should not abort the run just for the sake of metadata.
        ip = 'N/A'

    dt = {
        'ip': ip,
        'st_time': args.st_time,
        'sha': args.sha,
        'seed': args.seed,
        'exp': args.exp,
        'args': vars(args),
    }
    payload = json.dumps(dt)

    dt_path = os.path.join(args.output_path, 'dt.json')
    with open(dt_path, 'w') as f:
        f.write(payload)


def write_train_done(args, epoch):
    """Write a completion marker to ``<args.output_path>/train_done.txt``.

    The file is overwritten with ``"<args.epochs>, <epoch + 1>"``, where
    ``epoch`` is the index passed by the caller.
    """
    marker_path = os.path.join(args.output_path, 'train_done.txt')
    with open(marker_path, 'w') as marker:
        marker.write(f'{args.epochs}, {epoch + 1}')


# Argument names, in output order, that get_config() serialises into the
# compact configuration string.
ks = [
    'method',
    'arch',
    'epochs',
    'optimizer',
    'lr',
    'lr_decay_epochs',
    'lr_decay_rate',
    'warmup_iters',
    'bsz',
    'dataset',
    'act',
    'eta_r',
    'run_cond',
    'n_layers',
    'latent_dim',
    'step_exp',
    'not_prg',
    'skip_bias',
    'last_cls',
    'z_norm',
    'alpha_mean',
    'w_norm',
    'w_reg',
    'reg_coef',
    'z_init',
    'comp_eta',
    'orthogonal_testing',
]


def get_config(args):
    """Build a compact, comma-separated summary of the arguments listed in ``ks``.

    For every name in the module-level ``ks`` list, in order, one
    ``<first character of the name>:<value>`` item is emitted. The value is
    ``str(value)`` with all spaces removed, and the strings ``'True'`` /
    ``'False'`` are shortened to ``'T'`` / ``'F'``.

    NOTE(review): only the first character of each name is used as the
    label, so names sharing an initial (e.g. ``lr`` and ``latent_dim``) are
    indistinguishable in the result. Kept as-is for compatibility; confirm
    that this is intended.

    Raises:
        ValueError: if ``args`` lacks any name listed in ``ks``; the message
            names the first missing one.
    """
    args_dt = vars(args)
    short_bool = {'True': 'T', 'False': 'F'}

    parts = []
    for k in ks:
        if k not in args_dt:
            raise ValueError('missing argument for config string: {}'.format(k))
        # Spaces are stripped from every value here, so the joined string
        # needs no further clean-up.
        d = str(args_dt[k]).replace(' ', '')
        parts.append('{}:{}'.format(k[0], short_bool.get(d, d)))
    return ','.join(parts)


def set_exp(args):
    """Extend ``args.exp`` with a descriptive, abbreviated run tag.

    Nothing happens when ``args.exp`` contains ``'test'`` or ``'tmp'``.
    Otherwise a tag is built from the dataset, architecture, method,
    ``min_val_sw``, ``sigma_b``, ``min_val_e``, ``T``, ``n_layers``,
    ``z_init`` and ``orthogonal_testing``; known words in the tag are
    abbreviated, and ``args.exp`` becomes ``<exp>_<tag>_<st_time>``.

    Abbreviations: ``cifar -> cf``, ``svhn -> sv``, ``resnet -> rs``,
    ``pcd -> spcn``, ``pc -> pcn``.
    """
    import re

    if 'test' in args.exp or 'tmp' in args.exp:
        return

    tag = '{}_{}_{}_sw{}_sb{}_et{}_T{}_L{}_{}_orth{}'.format(
        args.dataset, args.arch, args.method, args.min_val_sw, args.sigma_b,
        args.min_val_e, args.T, args.n_layers, args.z_init,
        args.orthogonal_testing)

    abbreviations = {
        'cifar': 'cf',
        'svhn': 'sv',
        'resnet': 'rs',
        'pcd': 'spcn',
        'pc': 'pcn',
    }
    # Substitute in a single pass so that replacement text is never scanned
    # again: chaining str.replace turned 'pcd' into 'spcn' and then, because
    # 'spcn' contains 'pc', into 'spcnn'. 'pcd' precedes 'pc' in the
    # alternation so the longer word wins.
    tag = re.sub('cifar|svhn|resnet|pcd|pc',
                 lambda match: abbreviations[match.group(0)], tag)

    args.exp = args.exp + '_' + tag + '_' + args.st_time


def set_fig_dir(args):
    """Set ``args.fig_dir`` to ``figures/<args.exp>`` and create the directory.

    The path is relative to the current working directory; missing parents
    are created and an existing directory is accepted.
    """
    fig_dir = os.path.join('figures/', args.exp)
    args.fig_dir = fig_dir
    pathlib.Path(fig_dir).mkdir(parents=True, exist_ok=True)
    # args.log_root = os.path.join('log/', args.exp)
    # pathlib.Path(args.log_dir).mkdir(parents=True, exist_ok=True)


def set_seed(seed):
    """Seed the NumPy, PyTorch (CPU and CUDA) and ``random`` generators.

    Only acts when ``seed >= 0``. Besides seeding, it switches cuDNN to
    deterministic mode and stores the seed in the ``PYTHONHASHSEED``
    environment variable.
    """
    if not seed >= 0:
        return

    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    os.environ['PYTHONHASHSEED'] = str(seed)


def load_pickle(args, log_path):
    """Load a pickled log object and return its length list.

    The unpickled object must provide ``clip_T`` and ``get_len_lst``: it is
    first clipped with ``args.T`` and then queried for the list that is
    returned. The object itself is dropped and a garbage collection is
    forced before returning, to release its memory.

    NOTE(review): ``pickle.load`` can execute arbitrary code contained in
    the file; only use this on log files from a trusted source.
    """
    with open(log_path, 'rb') as handle:
        record = pickle.load(handle)

    record.clip_T(args.T)
    lengths = record.get_len_lst()

    # Free the (potentially large) log object right away.
    del record
    gc.collect()
    return lengths


def get_sha():
    """Describe the git state of the directory that contains this file.

    Returns a string of the form
    ``"sha: <commit>, status: <state>, branch: <name>"``. The lookup is
    best effort: as soon as any git call fails, the remaining fields keep
    their defaults (``N/A`` for sha and branch, ``clean`` for the status).
    """
    repo_dir = os.path.dirname(os.path.abspath(__file__))
    sha, status, branch = 'N/A', "clean", 'N/A'

    def git(*argv):
        raw = subprocess.check_output(['git', *argv], cwd=repo_dir)
        return raw.decode('ascii').strip()

    try:
        sha = git('rev-parse', 'HEAD')
        # Output intentionally unused; the call only has to succeed.
        subprocess.check_output(['git', 'diff'], cwd=repo_dir)
        changed = git('diff-index', 'HEAD')
        status = "has uncommited changes" if changed else "clean"
        branch = git('rev-parse', '--abbrev-ref', 'HEAD')
    except Exception:
        pass

    return f"sha: {sha}, status: {status}, branch: {branch}"


def _drive_create_folder(drive_service, name, parent_id):
    """Create a Drive folder ``name`` inside ``parent_id`` and return its id."""
    request_body = {
        'name': name,
        'parents': [parent_id],
        'mimeType': 'application/vnd.google-apps.folder',
    }
    created = drive_service.files().create(
        body=request_body, fields='id').execute()
    return created.get('id')


def _drive_upload_csv(drive_service, file_path, folder_id):
    """Upload the CSV at ``file_path`` into ``folder_id`` under its base name."""
    request_body = {
        'name': os.path.basename(file_path),
        'parents': [folder_id],
    }
    media = MediaFileUpload(file_path, mimetype='text/csv')
    return drive_service.files().create(
        body=request_body, media_body=media,
        fields='id,webViewLink').execute()


def upload_csv(args, desc):
    """Upload the result CSV files of a run to Google Drive.

    A folder named ``args.upload_dir`` (``'test'`` when empty) is created
    under the Drive folder configured for ``args.dataset``. Inside it, a
    run folder named ``<desc>_<args.start_time>_<git hash>`` is created,
    where every run of three spaces in ``desc`` is replaced by ``_``.
    ``max.csv``, ``top1.csv``, ``top2.csv`` and ``all.csv`` from
    ``args.output_path`` are always uploaded; ``calibration.csv`` and
    ``inex.csv`` are uploaded only if they exist.

    Credentials are read from ``js_file/token.json``. When that file is
    absent, the OAuth flow is run with ``js_file/client.json`` and the
    resulting token is stored at ``js_file/token.json`` for later calls.

    Raises:
        KeyError: if no Drive parent folder is configured for
            ``args.dataset``.
    """
    upload_path = desc.replace('   ', '_')

    SCOPES = [
        'https://www.googleapis.com/auth/drive',
    ]
    parents_ids = {
        'cifar10': '1sGHhK3SLdHAm8Sebrjx4yg7el8bdIJSg',
        'cifar100': "1T5Btmk7LyWp3aMkJtoaoatdKC1IKfPPT",
    }
    token_path = 'js_file/token.json'

    if os.path.exists(token_path):
        creds = Credentials.from_authorized_user_file(token_path, SCOPES)
    else:
        flow = InstalledAppFlow.from_client_secrets_file(
            'js_file/client.json', SCOPES
        )
        creds = flow.run_local_server(port=0)
        # Store the token where it is looked up above. It used to be written
        # to ./token.json, so it was never found again and the interactive
        # flow ran on every call.
        with open(token_path, 'w') as token:
            token.write(creds.to_json())
    drive_service = build('drive', 'v3', credentials=creds)

    if not args.upload_dir:
        args.upload_dir = 'test'
    folder_id = _drive_create_folder(
        drive_service, args.upload_dir, parents_ids[args.dataset])
    print("Folder Link:", f'https://drive.google.com/drive/folders/{folder_id}')

    # NOTE(review): `args.start_time` and `get_current_git_hash` are not
    # defined in the visible part of this file (other code here uses
    # `args.st_time` and `get_sha`); confirm they exist where this runs.
    run_folder_name = f'{upload_path}_{args.start_time}_{get_current_git_hash()}'
    folder_id = _drive_create_folder(drive_service, run_folder_name, folder_id)

    # Mandatory result files: a missing one is an error.
    for amp_s in ['max', 'top1', 'top2', 'all']:
        file_path = os.path.join(args.output_path, f'{amp_s}.csv')
        _drive_upload_csv(drive_service, file_path, folder_id)

    # Optional result files: uploaded only when present.
    for optional_name in ['calibration.csv', 'inex.csv']:
        file_path = os.path.join(args.output_path, optional_name)
        if os.path.exists(file_path):
            _drive_upload_csv(drive_service, file_path, folder_id)


def get_valid_unit(x):
    """Round ``x`` to three decimal places.

    A value that compares equal to zero is handed back unchanged.
    """
    return x if x == 0.0 else round(x, 3)

