import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
import random
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_target_dataset, get_network, get_eval_pool, evaluate_synset, get_time, DiffAugment, ParamDiffAug
from utils2 import get_pretrained_network, common_corruptions, get_rampup_weight
from losses import MMD_loss
import wandb


import argparse
import torch

from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.engine import build_trainer

from dassl.engine.trainer_myada import SimpleNet
from utils2 import my_get_pretrained_network, my_get_pretrained_network_ossfda
from style_op import style_inject, get_styles, normalize_style


class SimpleNet1(SimpleNet):
    """``SimpleNet`` variant that exposes the backbone's ``forward_DM`` pass."""

    def forward_DM(self, x, style_index=None, styles=None, norm_layer=None, inject_layer=None, return_style=False):
        """Run ``self.backbone.forward_DM`` on ``x`` and return its result.

        All arguments are forwarded to the backbone by keyword.
        ``norm_layer`` and ``inject_layer`` default to a fresh empty list on
        every call (``None`` means "empty list"), so one call can never see
        a list object left over from, or mutated by, a previous call.
        """
        # A literal `[]` default would be a single list shared by all calls.
        norm_layer = [] if norm_layer is None else norm_layer
        inject_layer = [] if inject_layer is None else inject_layer
        return self.backbone.forward_DM(
            x,
            style_index=style_index,
            styles=styles,
            norm_layer=norm_layer,
            inject_layer=inject_layer,
            return_style=return_style,
        )


def set_seed(seed=0):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: integer seed handed to every generator (default 0).
    """
    # Same seeding order as before: stdlib, NumPy, torch, torch CUDA.
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False      # no autotuner
    cudnn.deterministic = True   # request deterministic cuDNN kernels


def print_args(args, cfg):
    """Print every attribute of ``args`` (sorted by name) followed by ``cfg``.

    Args:
        args: namespace-like object; its ``__dict__`` entries are printed
            as ``name: value`` lines in sorted key order.
        cfg: config object, printed with a plain ``print``.
    """
    for banner_line in ("***************", "** Arguments **", "***************"):
        print(banner_line)

    attributes = vars(args)
    for name in sorted(attributes):
        print("{}: {}".format(name, attributes[name]))

    for banner_line in ("************", "** Config **", "************"):
        print(banner_line)
    print(cfg)


def reset_cfg(cfg, args):
    """Copy command-line values from ``args`` onto ``cfg`` in place.

    The first group of options is copied only when the argument value is
    truthy (so ``""``, ``None``, ``0`` and ``[]`` leave the config entry
    untouched). The second group (ADA / evaluator / TTA settings) is
    always copied.

    Args:
        cfg: nested config object whose attributes are overwritten.
        args: parsed arguments providing the attributes listed below.
    """

    def assign(dotted_key, value):
        # Walk down to the parent node, then set the leaf attribute on it.
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for part in parents:
            node = getattr(node, part)
        setattr(node, leaf, value)

    # (argument name, config key) pairs applied only when the value is truthy.
    optional_overrides = (
        ("root", "DATASET.ROOT"),
        ("output_dir", "OUTPUT_DIR"),
        ("resume", "RESUME"),
        ("seed", "SEED"),
        ("source_domains", "DATASET.SOURCE_DOMAINS"),
        ("target_domains", "DATASET.TARGET_DOMAINS"),
        ("transforms", "INPUT.TRANSFORMS"),
        ("trainer", "TRAINER.NAME"),
        ("backbone", "MODEL.BACKBONE.NAME"),
        ("head", "MODEL.HEAD.NAME"),
        # for mire
        ("nocls", "OSDG.NO_CLS"),
        ("adddims", "OSDG.ADD_DIMS"),
        ("lr", "OPTIM.LR"),
        ("bs", "DATALOADER.TRAIN_X.BATCH_SIZE"),
        ("epochs", "OPTIM.MAX_EPOCH"),
    )
    for arg_name, cfg_key in optional_overrides:
        value = getattr(args, arg_name)
        if value:
            assign(cfg_key, value)

    # (argument name, config key) pairs that are always written, for ADA.
    unconditional_overrides = (
        ("bp_grl", "ADA.bp_grl"),
        ("mining_grl", "ADA.mining_grl"),
        ("topk", "ADA.topk"),
        ("mining_th", "ADA.mining_th"),
        ("fda_loss_coef", "ADA.fda_loss_coef"),
        ("ua_loss_coef", "ADA.ua_loss_coef"),
        ("ua_loss_coef1", "ADA.ua_loss_coef1"),
        ("penalty_coef", "ADA.penalty_coef"),
        ("smooth_coef", "ADA.smooth_coef"),
        ("warmup_epoch", "ADA.warmup_epoch"),
        ("adv_grl", "ADA.adv_grl"),
        ("evaluator", "TEST.EVALUATOR"),
        ("tta_lr", "ADA.TTA.OPTIM.LR"),
        ("tta_epoch", "ADA.TTA.epoch"),
    )
    for arg_name, cfg_key in unconditional_overrides:
        assign(cfg_key, getattr(args, arg_name))


def extend_cfg(cfg):
    """Hook for registering extra config variables; currently does nothing.

    ``cfg`` is left untouched. New variables would be added here, e.g.::

        from yacs.config import CfgNode as CN
        cfg.TRAINER.MY_MODEL = CN()
        cfg.TRAINER.MY_MODEL.PARAM_A = 1.
        cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
        cfg.TRAINER.MY_MODEL.PARAM_C = False
    """
    return None


def setup_cfg(args):
    """Assemble and freeze the run configuration.

    Starts from ``get_cfg_default()``, runs ``extend_cfg`` on it, merges the
    dataset config file and then the method config file (each only when its
    path is truthy), applies the command-line values via ``reset_cfg`` and
    finally freezes the result.

    Args:
        args: parsed arguments; ``dataset_config_file`` and ``config_file``
            are read here, the remaining fields by ``reset_cfg``.

    Returns:
        The frozen config object.
    """
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # Steps 1 and 2: dataset config file first, method config file second.
    for attr_name in ("dataset_config_file", "config_file"):
        file_path = getattr(args, attr_name)
        if file_path:
            cfg.merge_from_file(file_path)

    # Step 3: values given on the command line.
    reset_cfg(cfg, args)

    # Step 4 (merging a free-form option list) is intentionally not performed.

    cfg.freeze()
    return cfg


def get_cfg(args):
    """Build the config for ``args`` and perform the start-up housekeeping.

    After ``setup_cfg`` the function seeds the RNGs via ``set_random_seed``
    when ``cfg.SEED`` is non-negative, sets up the logger on
    ``cfg.OUTPUT_DIR``, turns on the cuDNN autotuner when CUDA is available
    and ``cfg.USE_CUDA`` is set, and prints the arguments, the config and
    the environment report.

    Returns:
        The config object produced by ``setup_cfg``.
    """
    cfg = setup_cfg(args)

    fixed_seed = cfg.SEED
    if fixed_seed >= 0:
        print("Setting fixed seed: {}".format(fixed_seed))
        set_random_seed(fixed_seed)

    setup_logger(cfg.OUTPUT_DIR)

    cuda_in_use = torch.cuda.is_available() and cfg.USE_CUDA
    if cuda_in_use:
        torch.backends.cudnn.benchmark = True

    print_args(args, cfg)
    print("Collecting env info ...")
    env_report = collect_env_info()
    print("** System info **\n{}\n".format(env_report))

    return cfg



def _build_dm_parser():
    """Create the parser for the dataset-condensation options read by ``main``."""
    parser = argparse.ArgumentParser(description='Parameter Processing for Dataset Condensation')
    parser.add_argument('--DM_dataset', type=str, default='pacs', help='dataset', choices=['pacs', 'office_home_dg', 'office31'])
    parser.add_argument('--data_path', type=str, default='/datasets/PACSori/', help='dataset path')
    parser.add_argument('--pretrained', action='store_true')
    parser.add_argument('--resume_path', type=str, default='resume', help='path to load pretrained weights')
    parser.add_argument('--resume_classifier_path', type=str, default='resume', help='path to load pretrained weights')
    parser.add_argument('--save_path', type=str, default='result', help='path to save results')
    parser.add_argument('--batch_real', type=int, default=64, help='batch size for real data')
    parser.add_argument("--norm-layer", nargs='+', type=int, default=[])
    parser.add_argument('--style-norm', action='store_true')
    parser.add_argument('--style-norm-realonly', action='store_true')

    parser.add_argument('--model', type=str, default='ConvNet', help='model')
    parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
    # S: the same to training model, M: multi architectures, W: net width, D: net depth,
    # A: activation function, P: pooling layer, N: normalization layer
    parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode')
    parser.add_argument('--num_exp', type=int, default=5, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=20, help='the number of evaluating randomly initialized models')
    parser.add_argument('--Iteration', type=int, default=20000, help='training iterations')
    parser.add_argument('--lr_img', type=float, default=1.0, help='learning rate for updating synthetic images')

    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')

    parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--loss_weight', type=float, default=1.0)
    parser.add_argument('--ce_loss_weight', type=float, default=0.0)
    parser.add_argument('--num_classes_inbatch', type=int, default=None)

    parser.add_argument('--corruption', type=str, default=None, choices=common_corruptions)
    parser.add_argument('--train-corruption', type=str, default=None, choices=common_corruptions)

    parser.add_argument('--parameters', type=str, default=None)

    parser.add_argument('--warmup-type', type=str, default='step')
    parser.add_argument('--warmup-length', type=int, default=0)
    return parser


def main(cfg, add_args):
    """Optimise a synthetic image set against features of a frozen pretrained network.

    For each of ``--num_exp`` experiments the real training images are
    grouped by class, ``--ipc`` synthetic images per class are initialised
    (from real images or noise) and then updated with SGD so that the
    ``forward_DM`` outputs of the synthetic images match those of randomly
    drawn real images under ``MMD_loss``. The network weights are never
    updated; only the synthetic pixels receive gradients.

    Args:
        cfg: config object; passed to ``get_dataset`` and used (together
            with ``cfg.MODEL``) to construct ``SimpleNet1``.
        add_args: list of command-line tokens parsed by the
            dataset-condensation parser.

    Raises:
        NotImplementedError: if ``--init clean`` is requested.

    Side effects:
        Creates ``--save_path`` if needed, writes image grids (``.png``) and
        result checkpoints (``.pt``) into it, and prints progress.
    """
    args = _build_dm_parser().parse_args(add_args)
    set_seed(args.seed)

    args.method = 'DM'
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    args.dsa = args.dsa_strategy not in ['none', 'None']
    # The pretrained network is always used, whatever --pretrained says:
    # the optimisation loop below has no other source for `net`.
    args.pretrained = True

    os.makedirs(args.save_path, exist_ok=True)

    # Iterations at which the synthetic set is visualised and checkpointed.
    if args.eval_mode in ('S', 'SS'):
        # max(1, ...) keeps the np.arange step non-zero when Iteration < 2.
        eval_it_pool = np.arange(0, args.Iteration + 1, max(1, args.Iteration // 2)).tolist()
    else:
        eval_it_pool = [args.Iteration]
    print('eval_it_pool: ', eval_it_pool)

    channel, im_size, num_classes, class_names, mean, std, dst_train, label_train, trainloader, dst_test, testloader = \
        get_dataset(args.DM_dataset, args.data_path, cfg=cfg)

    args.num_classes = num_classes
    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

    mmd_loss = MMD_loss

    # Accuracy containers are only created and saved here; nothing in this
    # function fills them.
    eval_mode = ['default', 'styin']
    accs_all_exps = {key: {mode: [] for mode in eval_mode} for key in model_eval_pool}

    # Accumulates (images, labels) snapshots across ALL experiments.
    data_save = []

    net = SimpleNet1(cfg, cfg.MODEL, num_classes)
    net = my_get_pretrained_network_ossfda(net, args)
    # Follow args.device instead of forcing CUDA so CPU-only hosts work too.
    net = net.to(args.device)

    # freeze model
    for param in net.parameters():
        param.requires_grad = False

    # Classes are processed in sub-batches of at most this many classes.
    num_classes_inbatch = args.num_classes_inbatch if args.num_classes_inbatch else num_classes
    # Ceiling division: a trailing partial sub-batch gets its own step.
    num_sub_batches = (num_classes + num_classes_inbatch - 1) // num_classes_inbatch

    for exp in range(args.num_exp):
        print('\n================== Exp %d ==================\n '%exp)
        print('Hyper-parameters: \n', args.__dict__)
        print('Evaluation model pool: ', model_eval_pool)

        ''' organize the real dataset '''
        images_all = [torch.unsqueeze(dst_train[i], dim=0) for i in range(len(dst_train))]
        labels_all = [label_train[i] for i in range(len(label_train))]

        # indices_class[c] lists the positions of all real images of class c.
        indices_class = [[] for _ in range(num_classes)]
        for i, lab in enumerate(labels_all):
            indices_class[lab].append(i)
        images_all = torch.cat(images_all, dim=0).to(args.device)
        labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)

        for c in range(num_classes):
            print('class c = %d: %d real images'%(c, len(indices_class[c])))
            if c > 20:
                break

        def get_images(c, n): # get random n images from class c
            # NOTE: returns fewer than n images when class c has fewer than n.
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle]

        for ch in range(channel):
            print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

        ''' initialize the synthetic data '''
        image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
        # Class c owns rows [c*ipc, (c+1)*ipc) of image_syn / label_syn.
        label_syn = torch.cat([torch.ones(args.ipc, dtype=torch.long, requires_grad=False, device=args.device)* i for i in range(num_classes)])

        if args.init == 'real':
            print('initialize synthetic data from random real images')
            for c in range(num_classes):
                image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
        elif args.init == 'clean':
            # `NotImplemented` is a constant, not an exception; raising it
            # would fail with a TypeError instead of a meaningful error.
            raise NotImplementedError("--init clean is not implemented")
        else:
            print('initialize synthetic data from random noise')

        ''' training '''
        optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
        optimizer_img.zero_grad()
        print('%s training begins'%get_time())

        for it in range(args.Iteration+1):
            rampup_weight = get_rampup_weight(it, args.warmup_length, args.warmup_type)

            if it in eval_it_pool:
                ''' visualize and save '''
                save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.DM_dataset, args.model, args.ipc, exp, it))
                image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
                # Undo the per-channel normalisation, then clamp to [0, 1].
                for ch in range(channel):
                    image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]
                image_syn_vis[image_syn_vis<0] = 0.0
                image_syn_vis[image_syn_vis>1] = 1.0
                save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.

            ''' Train synthetic data '''
            net.train()

            # freeze model (re-applied each iteration, as before)
            for param in net.parameters():
                param.requires_grad = False

            loss_avg = 0

            ''' update synthetic data '''
            # Fresh random class order for every iteration.
            classes_indices = torch.randperm(num_classes)

            for sub_idx in range(num_sub_batches):
                optimizer_img.zero_grad()

                loss = torch.tensor(0.0).to(args.device)

                start_index = num_classes_inbatch * sub_idx
                end_index = min(num_classes_inbatch * (sub_idx + 1), num_classes)
                # Number of classes actually present in this sub-batch; the
                # last one is smaller when num_classes is not a multiple of
                # num_classes_inbatch.
                classes_in_sub_batch = end_index - start_index

                images_real_all = []
                images_syn_all = []

                for cls_idx in range(start_index, end_index):
                    c = classes_indices[cls_idx].item()

                    img_real = get_images(c, args.batch_real)
                    img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

                    if args.dsa:
                        # The same seed is given to both calls so real and
                        # synthetic images are augmented with matching
                        # randomness.
                        seed = int(time.time() * 1000) % 100000
                        img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                        img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    images_real_all.append(img_real)
                    images_syn_all.append(img_syn)

                images_real_all = torch.cat(images_real_all, dim=0).to(args.device)
                images_syn_all = torch.cat(images_syn_all, dim=0).to(args.device)

                # Real features are targets only, so no graph is built for them.
                with torch.no_grad():
                    output_real = net.forward_DM(images_real_all, norm_layer=args.norm_layer if args.style_norm or args.style_norm_realonly else [])
                    output_real = output_real.detach()

                output_syn = net.forward_DM(images_syn_all, norm_layer=args.norm_layer if args.style_norm else [])

                # Group features per class: (classes, samples per class, features).
                loss += mmd_loss(
                    output_real.reshape(classes_in_sub_batch, args.batch_real, -1),
                    output_syn.reshape(classes_in_sub_batch, args.ipc, -1),
                )

                total_loss = loss * args.loss_weight
                total_loss *= rampup_weight

                total_loss.backward()
                optimizer_img.step()

                loss_avg += loss.item()

            loss_avg /= (num_classes)

            if it%10 == 0:
                # Gradient statistics of the last sub-batch step.
                abs_mean_grad = torch.mean(torch.abs(image_syn.grad)).item()
                max_grad = torch.max(image_syn.grad).item()
                min_grad = torch.min(image_syn.grad).item()
                print('image gradient: %.4f, %.4f, %.4f' % (abs_mean_grad, max_grad, min_grad))

            if it%5 == 0:
                print('%s iter = %05d, loss(real, syn) = %.4f' % (get_time(), it, loss_avg))

            if it in eval_it_pool:
                data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
                # The final iteration gets the un-suffixed file name.
                if it == args.Iteration:
                    torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.DM_dataset, args.model, args.ipc)))
                else:
                    torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc_%diter.pt'%(args.method, args.DM_dataset, args.model, args.ipc, it)))

    print('\n==================== Final Results ====================\n')



def build_cfg_parser():
    """Create the parser for the config-level (Dassl / OSDG / ADA) options.

    Abbreviated long options are disabled so that a flag meant for the
    dataset-condensation parser in ``main`` (for example ``--model``) is not
    silently consumed here as a prefix of another option (``--model-dir``);
    unrecognised flags are instead left for ``parse_known_args`` to return.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--root", type=str, default="data/PACSori/", help="path to dataset")
    parser.add_argument(
        "--output-dir", type=str, default="", help="output directory"
    )
    parser.add_argument(
        "--resume",
        type=str,
        default="",
        help="checkpoint directory (from which the training resumes)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="only positive value enables a fixed seed"
    )
    parser.add_argument(
        "--source-domains",
        type=str,
        nargs="+",
        help="source domains for DA/DG"
    )
    parser.add_argument(
        "--target-domains",
        type=str,
        nargs="+",
        help="target domains for DA/DG"
    )
    parser.add_argument(
        "--transforms", type=str, nargs="+", help="data augmentation methods"
    )
    parser.add_argument(
        "--config-file", type=str, default="/DasslOS/configs/trainers/dg/vanilla/pacs.yaml", help="path to config file"
    )
    parser.add_argument(
        "--dataset-config-file",
        type=str,
        default="/DasslOS/configs/datasets/dg/pacs.yaml",
        help="path to config file for dataset setup",
    )
    parser.add_argument(
        "--trainer", type=str, default="Vanilla2", help="name of trainer"
    )
    parser.add_argument(
        "--backbone", type=str, default="", help="name of CNN backbone"
    )
    parser.add_argument("--head", type=str, default="", help="name of head")
    parser.add_argument(
        "--eval-only", action="store_true", help="evaluation only"
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default="",
        help="load model from this directory for eval-only mode",
    )
    parser.add_argument(
        "--load-epoch",
        type=int,
        help="load model weights at this epoch for evaluation"
    )
    parser.add_argument(
        "--no-train", action="store_true", help="do not call trainer.train()"
    )

    # for OSDG
    # `type=list` would split a single value into its characters
    # ("horse" -> ['h', 'o', 'r', 's', 'e']); take one or more names instead.
    parser.add_argument(
        "--nocls", type=str, nargs="+", default=['horse', 'house', 'person'], help=""
    )
    parser.add_argument(
        "--adddims", type=int, default=1,  help=""
    )

    parser.add_argument(
        "--lr", type=float, default=1e-3, help="learning rate"
    )
    parser.add_argument(
        "--bs", type=int, default=64, help="batch size"
    )
    parser.add_argument(
        "--epochs", type=int, default=30, help="epochs"
    )

    # ADA parameters
    parser.add_argument('--bp_grl', type=float, default=0.5, metavar='TH', help='grl adversarial weight (default: 0.5)')
    parser.add_argument('--mining_grl', type=float, default=0.2, metavar='TH', help='grad scaler (default: 0.2)')
    parser.add_argument('--topk', default=1, type=int, help='select potential unk regions, depends on number of known classes')
    parser.add_argument('--mining_th', default=1.0, type=float, metavar='TH', help='unk label')
    parser.add_argument('--fda_loss_coef', default=1.0, type=float)
    parser.add_argument('--ua_loss_coef', default=1.0, type=float)
    parser.add_argument('--ua_loss_coef1', default=0.0, type=float)
    parser.add_argument('--smooth_coef', default=1.0, type=float, help="smoothed CE, gt 1")
    parser.add_argument('--warmup_epoch', default=0, type=int)
    parser.add_argument('--penalty_coef', default=0.05, type=float)
    parser.add_argument('--evaluator', default="Classification_plain")
    parser.add_argument('--adv_grl', type=float, default=0.1, metavar='TH', help='grl adversarial weight (default: 0.1)')
    parser.add_argument('--tta_lr', default=1e-3, type=float)
    parser.add_argument('--tta_epoch', default=15, type=int)
    return parser


if __name__ == "__main__":
    # Options known to this parser configure `cfg`; everything else is
    # handed to `main`, which parses it with its own parser.
    # NOTE: `--seed` is defined by both parsers and is consumed here, so
    # main's own `--seed` always keeps its default.
    args, unknown = build_cfg_parser().parse_known_args()

    cfg = get_cfg(args)

    main(cfg, unknown)


