import argparse
import logging
import os
import pickle
import math
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import random

# import torchviz as tv

from utils.check_dataset import check_dataset
from utils.check_model import check_model
from utils.resnet_icml_ilsvrc import resnet18
from utils.common import count_parameters, AverageMeter, accuracy, set_logging_config,getUniqueFileHandler,load_weights_to_flatresnet
from utils import decision_makers, resnet_icml_ilsvrc, gumbel_softmax
from arch import resnet_atl, vgg_atl

# Names of every concrete decision-maker class found in ``utils.decision_makers``.
# The list is used below as the set of valid values for ``--dm-algo``.
dm_algo_names = []
for cls in vars(decision_makers).values():
    try:
        is_concrete_dm = (issubclass(cls, decision_makers.BaseDecisionMaker)
                          and cls is not decision_makers.BaseDecisionMaker)
        if is_concrete_dm:
            dm_algo_names.append(cls.name)
    except (TypeError, AttributeError):
        # TypeError: the module attribute is not a class (function, constant, ...).
        # AttributeError: the class exposes no ``name`` (or the base class is missing).
        # Anything else (e.g. KeyboardInterrupt) is deliberately left to propagate.
        continue


def get_accuracy_gain_reward(output, target, topk=(1,), prev_output=None):
    """Return the per-sample gain in the number of correct top-1 predictions.

    Args:
        output: class scores of the current model; the arg-max along
            dimension 1 is taken as the prediction.
        target: ground-truth class indices, one per row of ``output``.
        topk: accepted for interface compatibility; not used.
        prev_output: optional class scores of a reference model. When
            omitted, the reference is treated as having zero correct
            predictions.

    Returns:
        ``(correct_now - correct_before) / batch_size`` as a Python float.
    """
    def _hits(scores):
        # Index of the largest score along dim 1 is the predicted class.
        labels = torch.max(scores.data, 1)[1]
        return (labels == target).sum().item()

    n_samples = target.size(0)
    hits_now = _hits(output)
    hits_before = 0. if prev_output is None else _hits(prev_output)
    return (hits_now - hits_before) / n_samples


def get_predicted_gain_reward(output, target, prev_output=None):
    """Return the cross-entropy improvement of ``output`` over ``prev_output``.

    Args:
        output: class scores (logits) of the model being evaluated.
        target: ground-truth class indices.
        prev_output: optional logits of the reference model.

    Returns:
        A 0-dim tensor equal to ``CE(prev_output, target) - CE(output, target)``
        (positive when ``output`` has the lower loss). The result is zero when
        no reference is supplied or when the difference is NaN.
    """
    if prev_output is None:
        # Nothing to compare against: the gain is zero. A tensor is returned so
        # callers can treat both cases alike (e.g. call ``.cpu()`` on it).
        return output.new_zeros(())

    loss = F.cross_entropy(output, target)
    loss_old = F.cross_entropy(prev_output, target)
    reward = loss_old - loss
    if torch.isnan(reward):
        # A NaN loss on either side carries no usable signal.
        reward = torch.zeros_like(reward)
    return reward


# Backward-compatible alias: the function was originally defined under this
# misspelled name, while the training loop calls ``get_predicted_gain_reward``.
tget_predicted_gain_reward = get_predicted_gain_reward


def _build_parser():
    """Create the command-line parser holding every experiment option.

    Returns:
        argparse.ArgumentParser: parser for the data, model, optimisation and
        decision-maker settings. ``--dataroot`` and ``--pairs`` are required.
    """
    parser = argparse.ArgumentParser(add_help=False)
    # Data
    parser.add_argument('--dataroot', required=True, help='Path to the dataset')
    parser.add_argument('--dataset', default='cub200')
    parser.add_argument('--datasplit', default='cub200')
    parser.add_argument('--datanoise', action='store_true', default=False)
    parser.add_argument('--batchSize', type=int, default=64, help='Input batch size')
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--seed', type=int, default=0)

    # Source / target models
    parser.add_argument('--source-model', default='resnet34', type=str)
    parser.add_argument('--source-domain', default='imagenet', type=str)
    parser.add_argument('--source-path', type=str, default=None)
    # NOTE(review): store_true with default=True means this flag can never be turned off.
    parser.add_argument('--pretrained-src-model', action='store_true', default=True)
    parser.add_argument('--target-model', default='resnet18', type=str)
    parser.add_argument('--pretrained-tg-model', action='store_true', default=False)
    parser.add_argument('--numTrain', type=int, default=100, help='Train sample size. 100 means use all')
    parser.add_argument('--ru-units', action='store_false', default=True)

    # NOTE(review): the default 'xstitch' is not one of the choices; argparse does not
    # validate defaults, so it is passed through unchanged when the flag is omitted.
    parser.add_argument('--transfer-mode', default='xstitch', type=str,
                        choices=['iden','simpleadd','wtadd','lincombine','factred', 'indep','spottune', 'attention'],
                        help='Different options to combine the source and target')
    parser.add_argument('--experiment', default='logs', help='Where to store models')
    parser.add_argument('--pairs', type=str, required=True,
                        help='auto or Write in format: 4-4,4-3,4-2,4-1,3-4,3-3,3-2,3-1,2-4,2-3,2-2,2-1,1-4,1-3,1-2,1-1')

    # Optimisation of the target network
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--lr', type=float, default=0.1, help='Initial learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--wd', type=float, default=0.0001, help='Weight decay')
    parser.add_argument('--nesterov', action='store_true')
    # NOTE(review): store_true with default=True means the schedule is always on.
    parser.add_argument('--schedule', action='store_true', default=True)

    # Optimisation of the combination ("meta") parameters
    parser.add_argument('--meta-lr', type=float, default=0.001, help='Initial learning rate for meta networks')
    parser.add_argument('--meta-wd', type=float, default=0.0001)
    parser.add_argument('--optimizer', type=str, default='sgd')
    parser.add_argument('--src-optimizer', type=str, default='adam')

    # Decision makers
    parser.add_argument('--dm-algo', default='EXP3', choices=dm_algo_names)
    parser.add_argument('--dm-lr', default=0.01, type=float, help='the learning rate for the decision maker')
    parser.add_argument('--T', type=int, default=2)
    parser.add_argument('--warm-start', action='store_true', default=False)
    parser.add_argument('--warm-start-epoch', type=int, default=100)

    parser.add_argument('--exp3_gamma', type=float, default=0.2, help='Exp3 exploration parameter')
    parser.add_argument('--exploration', type=float, default=0.2, help='exploration parameter')
    parser.add_argument('--exploration_gamma', type=float, default=1e6, help='an exploration decay parameter. '
            'exploration is computed as exp = (gamma/(timestep + gamma))')
    parser.add_argument('--anneal', action='store_true', default=False)
    parser.add_argument('--exploit_reset', action='store_true', default=False)
    parser.add_argument('--exploit_epoch', type=int, default=200)

    parser.add_argument('--reward_clip_min', type=float, default=-1.0, help='Clip rewards to this min value')
    parser.add_argument('--reward_clip_max', type=float, default=1.0, help='Clip rewards to this max value')
    return parser


def main():
    """Train a target network that reuses features of a frozen source network.

    Each epoch has two stages:

    1. Decision-maker stage (only when ``--pairs auto``): one decision maker
       per target layer picks an arm, i.e. a (source feature, combination op)
       pair. The resulting routing is scored on a validation batch against the
       target network run without source features, and the decision makers are
       updated with the clipped, entropy-penalised reward. The best-scoring
       routing of the epoch is kept for stage 2.
    2. Model-learning stage: the target network (and its combination
       parameters, or the policy network in ``spottune`` mode) is trained on
       the training loader with the selected routing.

    Per-epoch metrics are collected and pickled to the results file at the end;
    a checkpoint is written every 50 epochs.
    """
    opt = _build_parser().parse_args()
    pair_str = 'auto' if opt.pairs == 'auto' else 'fixed'
    opt.results_filename = (
        f'results/{pair_str}_route_{opt.dataset}_{opt.dm_algo}_{opt.source_model}Vs{opt.target_model}'
        f'_{opt.batchSize}_{opt.transfer_mode}_runId{opt.seed}_{opt.epochs}_{opt.lr}')
    results_ofile = getUniqueFileHandler(opt.results_filename + '_results')

    # Explicit dispatch instead of eval(): same lookup order as before
    # ('resnet' is tested first), but an unknown family fails with a clear error.
    arch_modules = (('resnet', resnet_atl), ('vgg', vgg_atl))

    def get_arch(model_name):
        """Return the ``arch`` module that implements ``model_name``."""
        for family, module in arch_modules:
            if family in model_name:
                return module
        raise ValueError('Unsupported model {!r}: expected a resnet or vgg variant'.format(model_name))

    src_net = get_arch(opt.source_model)
    target_net = get_arch(opt.target_model)

    # The first entry of each module's __all__ is the class carrying the routing metadata.
    opt.source_feature_ids = src_net.__dict__[src_net.__dict__['__all__'][0]].source_feature_ids
    opt.target_decisioner_ids = target_net.__dict__[target_net.__dict__['__all__'][0]].target_route_ids
    opt.comb_ops = list(target_net.OPS.keys())
    opt.source_input_pass_id = 5  # ID corresponding to PASS/skip action
    # feat2id: feature slot -> source feature id (the extra last slot is PASS).
    feat2id = dict(enumerate(opt.source_feature_ids + [opt.source_input_pass_id]))
    # ops2id: op slot -> name of the combination op.
    ops2id = dict(enumerate(opt.comb_ops))
    # An arm is a (feature slot, op slot) pair; arms2id maps the arm index to it.
    arms = [(x0, y0) for x0 in feat2id for y0 in ops2id]
    arms2id = dict(enumerate(arms))
    opt.narms = len(arms)

    # Initialize Decision Maker (one per routed target layer, chained through next_qlearner)
    decisioners = [None for _ in range(len(opt.target_decisioner_ids))]
    if opt.pairs == 'auto':
        dm_algo = getattr(decision_makers, opt.dm_algo)
        last_dm = None
        for tt in opt.target_decisioner_ids:
            new_dm = dm_algo(opt.narms, settings=opt, next_qlearner=last_dm)
            last_dm = new_dm
            decisioners[tt] = new_dm

    opt.in_epochs = 1       # passes over the training loader per outer epoch
    opt.b = 10
    opt.max_dec_no = 50     # decision-maker iterations per epoch
    opt.average_all = False
    results = {
        'rewards': [],
        'loss': [],
        'train_accuracy': [],
        'valid_accuracy': [],
        'test_accuracy': [],
        'pairs': [],
    }

    # Seeds
    random.seed(opt.seed)
    cudnn.benchmark = True
    torch.manual_seed(opt.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(opt.seed)
    cudnn.enabled = True
    np.random.seed(opt.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_logging_config(opt.experiment)
    logger = logging.getLogger('main')
    logger.info(' '.join(os.sys.argv))
    logger.info(opt)

    # load dataloaders: loaders[0] is used for training, loaders[1] for
    # validation / decision making and loaders[2] for testing.
    loaders = check_dataset(opt)

    # load source model
    if opt.source_domain == 'imagenet':
        # In this project sample code, we assumed the source network is always resnet pretrained with Imagenet
        source_model = src_net.__dict__[opt.source_model](pretrained=opt.pretrained_src_model).to(device)
    else:
        opt.model = opt.source_model
        source_path = opt.source_path
        ckpt = torch.load(source_path, map_location=device)
        # Number of class needs to be adjusted here
        source_model = src_net.__dict__[opt.source_model](num_classes=40, pretrained=opt.pretrained_src_model)
        source_model.load_state_dict(ckpt['target_model'])
        source_model.to(device)

    opt.source_info = (opt.source_feature_ids, opt.source_input_pass_id, decisioners, source_model.channels)
    if opt.transfer_mode == 'spottune':
        opt.source_info = None

    # load target model
    opt.num_target_classes = len(loaders[0].dataset.classes)
    target_model = target_net.__dict__[opt.target_model](num_classes=opt.num_target_classes,
                                                         pretrained=opt.pretrained_tg_model,
                                                         init_weights=True,
                                                         source_info=opt.source_info,
                                                         ru_units=opt.ru_units)

    if opt.transfer_mode == 'spottune':
        load_weights_to_flatresnet(target_model)
        # Two outputs per target layer, later viewed as (batch, nlayers, 2).
        policy_net = resnet_atl.resnet10(num_classes=target_model.nlayers * 2).to(device)

        # Freeze parallel blocks
        for name, m in target_model.named_modules():
            if isinstance(m, nn.Conv2d) and 'parallel_' in name:
                m.weight.requires_grad = False

    target_model = target_model.to(device)  # Move to GPU after loading pretrained model if available

    print('# Parameters:' + str(count_parameters(target_model)))

    # Optimizers: parameters whose name contains 'cs', 'ru' or 'policy' go to the
    # "source"/meta optimizer, everything else trainable goes to the target optimizer.
    target_params = [param for name, param in target_model.named_parameters()
                     if ('cs' not in name and 'ru' not in name and 'policy' not in name) and param.requires_grad is True]
    if opt.optimizer == 'sgd':
        target_optimizer = optim.SGD(target_params, lr=opt.lr, momentum=opt.momentum,
                                     weight_decay=opt.wd)
    else:
        target_optimizer = optim.Adam(target_params, lr=opt.lr, weight_decay=opt.wd)

    if opt.transfer_mode == 'spottune':
        weight_params = [param for name, param in policy_net.named_parameters()]
    else:
        weight_params = [param for name, param in target_model.named_parameters()
                         if ('cs' in name or 'ru' in name or 'policy' in name) and param.requires_grad is True]

    if opt.src_optimizer == 'sgd':
        source_optimizer = optim.SGD(weight_params, lr=opt.meta_lr, weight_decay=opt.meta_wd,
                                     momentum=opt.momentum)
    else:
        source_optimizer = optim.Adam(weight_params, lr=opt.meta_lr, weight_decay=opt.meta_wd)

    state = {
        'target_model': target_model.state_dict(),
        'target_optimizer': target_optimizer.state_dict(),
        'source_optimizer': source_optimizer.state_dict(),
        'opt': opt,
        'best': (0.0, 0.0)
    }

    # Each DecisionMaker (one for each target layer) maps Target Intermediate layer
    # to one of the K Source Intermediate layers (arms)
    # K arms are the number of source intermediate layers + Skip arms.
    # Pairs are a list of (target_layer_id, source_layer_id, transfer_type)
    pairs = []
    if not (opt.pairs == 'auto'):
        if opt.transfer_mode == 'attention':
            pairs = [(layer_id, -1, 'wtadd') for layer_id in range(target_model.nlayers)]
        else:
            pairs = [(int(pair.split('-')[0]), int(pair.split('-')[1]), opt.transfer_mode)
                     for pair in opt.pairs.split(',') if opt.pairs.strip() != '']
    state['pairs'] = pairs

    src_scheduler = optim.lr_scheduler.CosineAnnealingLR(source_optimizer, opt.epochs)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(target_optimizer, opt.epochs)

    def validate(model, loader):
        """Return the running average of the top-1 accuracy of ``model`` over ``loader``."""
        acc = AverageMeter()
        model.eval()
        if opt.transfer_mode == 'spottune':
            policy_net.eval()
        # Evaluation never back-propagates, so no autograd graph needs to be built.
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                if opt.transfer_mode == 'spottune':
                    policy_probs = policy_net(x)[0]
                    action = gumbel_softmax.gumbel_softmax(policy_probs.view(policy_probs.size(0), -1, 2))
                    policy_weights = action[:, :, 1]
                    y_pred, _ = model(x, weights=policy_weights)
                else:
                    source_out, source_features = source_model(x)
                    y_pred, _ = model(x, source_features, state['pairs'])
                acc.update(accuracy(y_pred.data, y, topk=(1,))[0].item(), x.size(0))
        return acc.avg

    def train_objective(data):
        """Run one forward pass on a training batch and return its cross-entropy loss.

        Side effects: stores the batch accuracy, loss value and the route-weight
        placeholder in ``state`` for logging.
        """
        x, y = data[0].to(device), data[1].to(device)

        if opt.transfer_mode == 'spottune':
            policy_probs = policy_net(x)[0]
            action = gumbel_softmax.gumbel_softmax(policy_probs.view(policy_probs.size(0), -1, 2))
            policy_weights = action[:, :, 1]
            y_pred, target_features = target_model(x, weights=policy_weights)
        else:
            # The source network is frozen: it only provides features.
            with torch.no_grad():
                source_out, source_features = source_model(x)
            y_pred, target_features = target_model(x, source_features, state['pairs'])
        state['accuracy'] = accuracy(y_pred.data, y, topk=(1,))[0].item()
        state['route_wts'] = '<>'
        loss = F.cross_entropy(y_pred, y)
        state['loss'] = loss.item()

        return loss

    state['iter'] = 0
    state['reward'] = 0
    state['weights'] = '<>'
    for epoch in range(opt.epochs):
        state['epoch'] = epoch
        source_model.eval()

        # 1) Decision Maker Stage
        val_loader_iter = iter(loaders[1])
        target_model.eval()
        dec_loss = 0.
        best_pairs = []
        best_score = float("-inf")
        if opt.pairs == 'auto':
            if opt.anneal and epoch > opt.exploit_epoch:
                # Exploit-only phase: freeze the routing to each decision maker's most
                # probable arm. The arm is decoded exactly as in the exploration branch
                # below (arm index -> (feature slot, op slot) -> ids).
                pairs = []
                for t in range(target_model.nlayers):
                    decisioner = decisioners[t]
                    if decisioner is None:
                        pairs.append((t, opt.source_input_pass_id, 'iden'))
                        continue
                    # NOTE(review): with average_all the original intended ref_id=-1 and a
                    # global transfer type; opt.transfer_mode is used as that type here.
                    # average_all is hard-coded to False above, so this path is inactive.
                    ref_id, transfer_type = -1, opt.transfer_mode
                    if not opt.average_all:
                        s = int(decisioner.probabilities.argmax())
                        f, o = arms2id[s]
                        ref_id, transfer_type = feat2id[f], ops2id[o]
                    pairs.append((t, ref_id, transfer_type))
                state['pairs'] = pairs
                opt.pairs = ''  # Turn-off decision making for exploit-only <FINAL> phase
            else:
                # Choose pairs using decisioner over several iterations
                for _ in range(opt.max_dec_no):
                    # Decision Maker: Selection
                    pairs = []
                    sel_actions = []
                    # Decision makers that actually acted, kept aligned with sel_actions
                    # so that a missing (None) decision maker cannot shift the pairing.
                    sel_decisioners = []
                    for t in range(target_model.nlayers):
                        decisioner = decisioners[t]
                        if decisioner is None:
                            pairs.append((t, opt.source_input_pass_id, 'iden'))
                            continue
                        ref_id, transfer_type = opt.source_input_pass_id, 'iden'
                        s = decisioner.step()
                        if not opt.average_all:
                            f, o = arms2id[s]
                            ref_id, transfer_type = feat2id[f], ops2id[o]
                        sel_actions.append(s)
                        sel_decisioners.append(decisioner)
                        pairs.append((t, ref_id, transfer_type))
                    state['pairs'] = pairs

                    # Decision Maker: Evaluate on one full-sized validation batch,
                    # restarting the loader when it is exhausted.
                    try:
                        val_data = next(val_loader_iter)
                        if val_data[1].size(0) != opt.batchSize:
                            val_data = next(val_loader_iter)
                    except StopIteration:
                        val_loader_iter = iter(loaders[1])
                        val_data = next(val_loader_iter)
                    val_x, val_y = val_data[0].to(device), val_data[1].to(device)
                    with torch.no_grad():
                        val_source_out, val_source_features = source_model(val_x)
                        y_pred, target_features = target_model(val_x, val_source_features, state['pairs'])
                        # Baseline: the target network without any source features.
                        y_pred_new, _ = target_model(val_x)
                    # Get reward: loss reduction obtained by the selected routing
                    raw_reward = get_predicted_gain_reward(y_pred.data, val_y, prev_output=y_pred_new.data)
                    raw_reward = raw_reward.cpu()
                    reward = np.clip(raw_reward, opt.reward_clip_min, opt.reward_clip_max)
                    state['reward'] = reward
                    state['weights'] = '<>'

                    for i, (action, decisioner) in enumerate(zip(sel_actions, sel_decisioners)):
                        # NOTE(review): the guard is always true, so next_action always equals
                        # this decision maker's own action; behaviour kept as-is — confirm
                        # whether sel_actions[i + 1] was intended.
                        next_action = sel_actions[i] if i < len(sel_actions) else None
                        probs = decisioner.probabilities
                        f_entropy = 0.
                        # Tolerant comparison: a float sum of probabilities is rarely exactly 1.0.
                        if np.isclose(np.sum(probs), 1.0):
                            # 0 * log(0) is taken as 0, which also avoids a math domain error.
                            f_entropy = np.sum([-pj * math.log(pj) for pj in probs if pj > 0])
                        # Reward focused on low entropy to peak on the optimal actions:
                        # maximize the reward and minimize the entropy.
                        reward_i = reward - 0.05 * f_entropy
                        reward_i = np.clip(reward_i, opt.reward_clip_min, opt.reward_clip_max)
                        dec_loss = decisioner.update(action, reward_i, next_action)
                        state['weights'] += ('[' + ('-'.join(format(x, ".2f") for x in decisioner.probabilities)) + ']<>')
                    if raw_reward > best_score:
                        best_score = raw_reward
                        best_pairs = pairs
                state['pairs'] = best_pairs

        # 2) Model Learning Stage
        target_model.train()
        if opt.transfer_mode == 'spottune':
            policy_net.train()
        train_acc = AverageMeter()
        for iepoch in range(opt.in_epochs):
            state['iepoch'] = iepoch
            for i, train_data in enumerate(loaders[0]):
                source_optimizer.zero_grad()
                target_optimizer.zero_grad()
                # Network Update
                loss = train_objective(train_data)
                (loss+dec_loss).backward()
                nn.utils.clip_grad_norm_(target_params, 5.)
                target_optimizer.step()
                source_optimizer.step()
                train_acc.update(state['accuracy'], train_data[0].size(0))
                logger.info(
                    '[Epoch {:3d}.{}] [Iter {:3d}] [Loss {:.4f}] [Acc {:.4f}] [Rew {:.2f}] [Pairs {}] [Wts {}] [Route Wts {}]'.format(
                        state['epoch'], state['iepoch'], state['iter'], state['loss'], state['accuracy'], state['reward'],
                        ''.join(map(str, state['pairs'])), state['weights'], state['route_wts']))
                state['iter'] += 1

        # Step the schedulers only after the optimizers have stepped. Stepping first
        # skips the initial learning rate and, with cosine annealing over opt.epochs,
        # leaves the final epoch training at the minimum learning rate.
        if opt.schedule:
            scheduler.step()
            src_scheduler.step()

        state['accuracy'] = train_acc.avg
        acc = (validate(target_model, loaders[1]),
               validate(target_model, loaders[2]))

        # Model selection is done on validation accuracy; the paired test accuracy is kept with it.
        if state['best'][0] < acc[0]:
            state['best'] = acc

        if (state['epoch'] + 1) % 50 == 0:
            state['target_model'] = target_model.state_dict()
            state['target_optimizer'] = target_optimizer.state_dict()
            state['source_optimizer'] = source_optimizer.state_dict()
            # NOTE(review): the handle returned here is not closed explicitly — confirm
            # what getUniqueFileHandler returns for this two-argument form.
            torch.save(state, getUniqueFileHandler(opt.results_filename, 'ckpt-{}'.format(state['epoch']+1)))
        # Collect Results
        results['train_accuracy'].append(state['accuracy'])
        results['valid_accuracy'].append(acc[0])
        results['test_accuracy'].append(acc[1])
        results['pairs'].append(state['pairs'])
        results['rewards'].append(state['reward'])
        results['loss'].append(state['loss'])
        # "best" reports the test accuracy at the epoch with the best validation accuracy.
        logger.info('             [Epoch {}] [train {:.4f}] [val {:.4f}] [test {:.4f}] [best {:.4f}]'
                    .format(epoch, state['accuracy'], acc[0], acc[1], state['best'][1]))

    try:
        pickle.dump(results, results_ofile)
    finally:
        # Flush and release the results file even if pickling fails.
        results_ofile.close()

if __name__ == '__main__':
    main()
