import os
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import random
import matplotlib.pyplot as plt
import pickle
import numpy as np
from torch.autograd import Variable
from datetime import datetime
import time

def write_log(args, message):
    """Append one line to ``args.log_file``.

    The line is prefixed with ``[<wall-clock time> | <HH:MM:SS since args.start_time>] ``.
    ``message`` must be a ``str`` (it is concatenated, not formatted).
    """
    wall_clock = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    since_start = seconds_to_hms(time.time() - args.start_time)
    prefix = "[" + wall_clock + " | " + since_start + "] "
    with open(args.log_file, 'a') as log_handle:
        log_handle.write(prefix + message + '\n')

def seconds_to_hms(seconds):
    """Format a duration in seconds as a zero-padded ``HH:MM:SS`` string.

    Fractional seconds are truncated. Hours are not wrapped at 24, and they widen
    beyond two digits when needed.
    """
    whole_hours, remainder = divmod(seconds, 3600)
    whole_minutes = remainder // 60
    leftover_seconds = seconds % 60
    parts = (int(whole_hours), int(whole_minutes), int(leftover_seconds))
    return "{:02}:{:02}:{:02}".format(*parts)

def init_random_seed(manual_seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, all CUDA devices).

    Args:
        manual_seed: seed to apply, or ``None`` to draw one with
            ``random.randint(1, 10000)``.

    Returns:
        The seed that was actually applied. Callers can log it, so that a run
        started with ``manual_seed=None`` can still be reproduced.
    """
    seed = random.randint(1, 10000) if manual_seed is None else manual_seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    return seed

def init_model(net, restore):
    """Optionally load saved weights into ``net``, then move it to the GPU if one is available.

    Args:
        net: the ``nn.Module`` to initialise. It is modified in place.
        restore: path to a saved ``state_dict``, or ``None``. Weights are loaded
            only when the path exists; a missing path is silently skipped.

    Returns:
        The same ``net`` object. ``net.restored`` is set to ``True`` only when
        weights were loaded from ``restore``.
    """
    if restore is not None and os.path.exists(restore):
        # Deserialize onto the CPU. Without map_location, a checkpoint holding
        # CUDA tensors cannot be loaded on a CPU-only machine. load_state_dict
        # copies the values into the module's own parameters, and the module
        # is moved to CUDA below when available.
        net.load_state_dict(torch.load(restore, map_location='cpu'))
        net.restored = True

    if torch.cuda.is_available():
        # cuDNN benchmark mode autotunes convolution algorithms.
        cudnn.benchmark = True
        net.cuda()

    return net

def save_model(args, net, filename):
    """Save ``net.state_dict()`` to ``os.path.join(args.model_root, filename)``.

    The ``args.model_root`` directory (and any missing parents) is created if needed.
    The destination path is printed after saving.
    """
    path = os.path.join(args.model_root, filename)
    # exist_ok=True avoids the exists()/makedirs() race when several runs
    # create the same model_root concurrently.
    os.makedirs(args.model_root, exist_ok=True)
    torch.save(net.state_dict(), path)
    print("save pretrained model to: {}".format(path))

def make_plot(args, src_acc, ul1_acc, ul2_acc, test_acc, lam1_list, lam2_list, ul1_pl_acc, ul2_pl_acc, mid=False):
    """Plot test-domain accuracy, pseudo-label accuracies and lambda curves, then save them.

    The figure is written to ``{args.output_dir}/{args.dataset}/{args.title}.png``.
    The plotted series are also pickled to ``./pickle/{args.title}_<series>.pickle``.
    Accuracy series are pickled after scaling to percent.

    Args:
        args: run configuration. Reads ``domains_letter``, ``dataset``, ``order``,
            ``order_num``, ``pl``, ``output_dir`` and ``title``.
        src_acc, ul1_acc, ul2_acc, test_acc: per-epoch accuracies, presumably
            fractions in [0, 1] (they are scaled by 100 for the % axis). Only
            ``test_acc`` is plotted by default; the others feed optional curves.
        lam1_list, lam2_list: per-epoch lambda values of the two unlabeled domains.
        ul1_pl_acc, ul2_pl_acc: per-epoch pseudo-label accuracies, plotted only
            when ``args.pl == 'shot'``.
        mid: unused; kept so existing call sites keep working.
    """
    domains = args.domains_letter[args.dataset]   # e.g. [a, c, p, s]
    # Reorder domain letters by args.order, e.g. [2, 0, 1, 3] -> [p, a, c, s].
    # Index 0 is the source, 1-2 the unlabeled domains, 3 the test domain.
    domain_order = [domains[i] for i in args.order]
    epochs = list(range(1, len(test_acc) + 1))

    # Reference accuracies (%) of other methods: one value per domain order,
    # indexed by args.order_num. Only the optional curves/title below use them.
    baselines = {
        'PACS': (
            [71.82, 94.94, 67.18, 77.39, 84.64, 72.24, 75.52, 70.98, 64.01, 66.86, 72.83, 67.21],  # EID
            [67.87, 95.09, 64.03, 78.44, 89.81, 68.30, 75.80, 61.19, 50.81, 44.60, 53.23, 48.38],  # MCD+MixStyle
        ),
        'Digits': (
            [51.56, 37.29, 53.30, 97.12, 58.60, 69.05, 87.73, 60.89, 87.51, 92.39, 64.21, 70.89],  # EID
            [45.36, 24.74, 48.34, 96.93, 36.71, 56.61, 66.49, 42.86, 84.13, 75.02, 49.29, 67.29],  # MCD+MixStyle
        ),
        'OfficeHome': (
            [48.32, 59.11, 66.54, 47.53, 60.38, 61.29, 46.11, 47.27, 66.04, 53.28, 48.80, 69.00],  # EID
            [44.93, 55.18, 65.29, 45.79, 56.77, 58.95, 42.93, 42.01, 63.28, 52.18, 46.04, 69.75],  # MCD+MixStyle
        ),
    }
    # Datasets without reference numbers yield None instead of an unbound-name crash.
    eid, mcd_mixstyle = baselines.get(args.dataset, (None, None))

    # Fractions -> percent for the "Accuracy (%)" axis.
    src_acc = [acc * 100 for acc in src_acc]
    test_acc = [acc * 100 for acc in test_acc]
    ul1_acc = [acc * 100 for acc in ul1_acc]
    ul2_acc = [acc * 100 for acc in ul2_acc]
    ul1_pl_acc = [acc * 100 for acc in ul1_pl_acc]
    ul2_pl_acc = [acc * 100 for acc in ul2_pl_acc]

    # Mean test accuracy over the last (up to) 5 epochs and its gap to EID.
    # float() accepts Python numbers, NumPy scalars and one-element tensors.
    last = test_acc[-5:]
    test_final = round(float(sum(last) / len(last)), 2) if last else 0.0
    diff_final = round(test_final - eid[args.order_num], 2) if eid is not None else None

    fig, ax1 = plt.subplots()

    # Optional extra curves (disabled); uncomment to compare against baselines.
    # ax1.plot(epochs, [eid[args.order_num]] * len(epochs), label='EID', linestyle='dashed', color='C0')
    # ax1.plot(epochs, src_acc, label='source_{}'.format(domain_order[0]), color='C1')
    # ax1.plot(epochs, ul1_acc, label='ul1_{}'.format(domain_order[1]), color='C3')
    # ax1.plot(epochs, ul2_acc, label='ul2_{}'.format(domain_order[2]), color='C4')
    # ax1.plot(epochs, [mcd_mixstyle[args.order_num]] * len(epochs), label='MCD+Mixstyle', linestyle='dashed', color='C5')
    ax1.plot(epochs, test_acc, label='Model accuracy on test domain ({})'.format(domain_order[3]), color='C2')
    if args.pl == 'shot':
        # Colours match the corresponding lambda curves drawn on ax2.
        ax1.plot(epochs, ul1_pl_acc, label='PL accuracy of unlabeled domain ({})'.format(domain_order[1]), color='C3')
        ax1.plot(epochs, ul2_pl_acc, label='PL accuracy of unlabeled domain ({})'.format(domain_order[2]), color='C0')

    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy (%)')
    ax1.set_ylim([60, 100])

    ax2 = ax1.twinx()
    ax2.plot(epochs, lam1_list, label='Lambda of unlabeled domain ({})'.format(domain_order[1]), linestyle='dotted', color='C3')
    ax2.plot(epochs, lam2_list, label='Lambda of unlabeled domain ({})'.format(domain_order[2]), linestyle='dotted', color='C0')
    ax2.set_ylabel('Lambda')
    ax2.set_ylim([0.1, 0.5])

    # One combined legend for both y-axes.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='lower right')

    # plt.title(diff_final)
    ax1.set_yticks([60, 70, 80, 90, 100])
    ax2.set_yticks([0.1, 0.2, 0.3, 0.4, 0.5])
    plt.grid(True)

    # The figure lives in a per-dataset subfolder, so create that, not just output_dir.
    os.makedirs(os.path.join(args.output_dir, args.dataset), exist_ok=True)
    file_path = '{}/{}/{}.png'.format(args.output_dir, args.dataset, args.title)
    fig.savefig(file_path)
    # Close the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

    os.makedirs('./pickle', exist_ok=True)
    series = {
        'test_acc': test_acc,
        'ul1_pl_acc': ul1_pl_acc,
        'ul2_pl_acc': ul2_pl_acc,
        'lam1_list': lam1_list,
        'lam2_list': lam2_list,
    }
    for series_name, values in series.items():
        with open(f"./pickle/{args.title}_{series_name}.pickle", "wb") as fw:
            pickle.dump(values, fw)

def make_plot_pre(args, src_valid_loss_list, src_valid_acc_list, ul1_valid_acc_list, ul2_valid_acc_list, test_acc_list):
    """Plot pre-training validation curves and save the figure under ``./pretrain/``.

    The left axis shows per-epoch accuracy (x100) for four series:
      - the source domain,
      - the two unlabeled domains,
      - the test domain,
      - the mean of the three non-source domains.
    The right axis shows the source validation loss.

    Each legend label carries the best epoch/value chosen by ``best_epoch`` /
    ``best_loss`` (5-epoch moving average). The title and file name encode the
    pre-training hyper-parameters read from ``args``.
    """
    domain_names = args.domains_letter[args.dataset]   # e.g. [a, c, p, s]
    # Reorder by args.order, e.g. [2, 0, 1, 3] -> [p, a, c, s].
    domain_order = [domain_names[i] for i in args.order]

    epochs = list(range(1, len(src_valid_loss_list) + 1))
    src_acc = [acc * 100 for acc in src_valid_acc_list]
    ul1_acc = [acc * 100 for acc in ul1_valid_acc_list]
    ul2_acc = [acc * 100 for acc in ul2_valid_acc_list]
    test_acc = [acc * 100 for acc in test_acc_list]

    # Per-epoch mean over the three non-source domains.
    avg_acc = [(a1 + a2 + a3) / 3 for a1, a2, a3 in zip(ul1_acc, ul2_acc, test_acc)]

    ep0, acc0 = best_epoch(src_acc)
    ep1, acc1 = best_epoch(ul1_acc)
    ep2, acc2 = best_epoch(ul2_acc)
    ep3, acc3 = best_epoch(test_acc)
    ep4, acc4 = best_epoch(avg_acc)
    ep5, loss = best_loss(src_valid_loss_list)

    fig, ax1 = plt.subplots()

    ax1.plot(epochs, src_acc, label='{}_ep{}_{}'.format(domain_order[0], ep0, acc0))
    ax1.plot(epochs, ul1_acc, label='{}_ep{}_{}'.format(domain_order[1], ep1, acc1))
    ax1.plot(epochs, ul2_acc, label='{}_ep{}_{}'.format(domain_order[2], ep2, acc2))
    ax1.plot(epochs, test_acc, label='{}_ep{}_{}'.format(domain_order[3], ep3, acc3))
    ax1.plot(epochs, avg_acc, label='Avg_ep{}_{}'.format(ep4, acc4))

    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy')

    ax2 = ax1.twinx()
    ax2.plot(epochs, src_valid_loss_list, label='{}_loss_ep{}_{}'.format(domain_order[0], ep5, loss))
    ax2.set_ylabel('Loss')

    # One combined legend for both y-axes.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

    title = '{}_ep{}_lr{}_sch_{}_wd{}_{}_ogcat_{}'.format(domain_order[0], args.epoch_pre, args.lr_pre, args.schedule_pre, args.wd_pre, args.augmentation, args.aug_ogcat)
    if args.classifier_wn:
        title += '_wn'
    if args.schedule_pre == 'poly':
        title += '_pw{}'.format(args.power_pre)
    if args.fea_norm:
        title += '_feanorm_temp{}'.format(args.classifier_temp)
    elif args.fea_norm2:
        title += '_feanorm2_temp{}'.format(args.classifier_temp)
    if args.no_bias:
        title += '_nobias'

    plt.title(title)

    # savefig does not create directories; make sure the target folder exists.
    os.makedirs('./pretrain', exist_ok=True)
    file_path = './pretrain/{}.png'.format(title)
    fig.savefig(file_path)
    # Close the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

def best_epoch(valid_acc, window=5):
    """Find the epoch whose trailing ``window``-epoch mean accuracy is highest.

    Args:
        valid_acc: per-epoch accuracies (Python numbers, NumPy scalars or
            one-element tensors).
        window: length of the trailing moving-average window (default 5). A
            sequence shorter than ``window`` is averaged over all its entries.

    Returns:
        ``(epoch_index, mean_acc)``:
          - ``epoch_index`` is the 0-based index of the last epoch in the best window.
          - ``mean_acc`` is that window's mean, rounded to 2 decimals as a Python float.
        On ties, the earliest window wins.

    Raises:
        ValueError: if ``valid_acc`` is empty or ``window`` < 1.
    """
    n = len(valid_acc)
    if n == 0:
        raise ValueError("valid_acc must contain at least one value")
    if window < 1:
        raise ValueError("window must be >= 1")

    width = min(window, n)
    top_ep = width - 1
    top_mean = sum(valid_acc[:width]) / width
    for ep in range(width, n):
        mean = sum(valid_acc[ep - width + 1:ep + 1]) / width
        if mean > top_mean:
            top_mean, top_ep = mean, ep
    # float() works for tensors/NumPy scalars as well as plain Python numbers.
    return top_ep, round(float(top_mean), 2)

def best_loss(valid_loss, window=5):
    """Find the epoch whose trailing ``window``-epoch mean loss is lowest.

    Args:
        valid_loss: per-epoch losses (Python numbers, NumPy scalars or
            one-element tensors).
        window: length of the trailing moving-average window (default 5). A
            sequence shorter than ``window`` is averaged over all its entries.

    Returns:
        ``(epoch_index, mean_loss)``:
          - ``epoch_index`` is the 0-based index of the last epoch in the best window.
          - ``mean_loss`` is that window's mean, rounded to 4 decimals as a Python float.
        On ties, the earliest window wins.

    Raises:
        ValueError: if ``valid_loss`` is empty or ``window`` < 1.
    """
    n = len(valid_loss)
    if n == 0:
        raise ValueError("valid_loss must contain at least one value")
    if window < 1:
        raise ValueError("window must be >= 1")

    width = min(window, n)
    low_ep = width - 1
    low_mean = sum(valid_loss[:width]) / width
    for ep in range(width, n):
        mean = sum(valid_loss[ep - width + 1:ep + 1]) / width
        if mean < low_mean:
            low_mean, low_ep = mean, ep
    # float() normalizes tensors/NumPy scalars before rounding.
    return low_ep, round(float(low_mean), 4)

def print_time(start_time, end_time):
    """Print the span between two timestamps (seconds) as whole hours and minutes.

    Both parts are truncated toward zero with ``int()``; leftover seconds are dropped.
    """
    duration = end_time - start_time
    hours_part = int(duration / 3600)
    minutes_part = int((duration % 3600) / 60)
    print("\nElapsed time: {} hours and {} minutes".format(hours_part, minutes_part))