import argparse
import os, sys

sys.path.append("./")

import os.path as osp
import torchvision
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import network
from torch.utils.data import DataLoader
import random, pdb, math, copy
from tqdm import tqdm
import pickle
from utils import *
from torch import autograd
from warmup_scheduler import GradualWarmupScheduler
from masking import Masking
from torch.optim.lr_scheduler import CosineAnnealingLR
import clip_co
import prompt_tuning

from torch import autocast
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from diffusers.utils import logging
import logging as py_logging

import sys
sys.path.append('./office-home')
import torch
from omegaconf import OmegaConf
from PIL import Image

def Entropy(input_):
    """Return the per-row entropy of a batch of probability vectors.

    Args:
        input_: tensor with at least two dimensions; the values along
            dimension 1 are treated as probabilities.

    Returns:
        Tensor equal to ``sum(-p * log(p + 1e-5))`` taken over dimension 1.
    """
    # The small constant keeps log() finite when a probability is exactly 0.
    epsilon = 1e-5
    return torch.sum(-input_ * torch.log(input_ + epsilon), dim=1)


def print_args(args):
    """Format every attribute of *args* as ``name:value`` lines under a banner.

    Args:
        args: object whose instance attributes are listed (for example an
            ``argparse.Namespace``).

    Returns:
        A string made of a separator line followed by one ``name:value``
        line per attribute, each terminated by a newline.
    """
    banner = "==========================================\n"
    rows = ("{}:{}\n".format(name, value) for name, value in vars(args).items())
    return banner + "".join(rows)


def op_copy(optimizer):
    """Record each parameter group's current ``lr`` under the key ``lr0``.

    Args:
        optimizer: object exposing ``param_groups``, a sequence of dicts that
            each contain an ``"lr"`` entry.

    Returns:
        The same optimizer object, modified in place.
    """
    for group in optimizer.param_groups:
        group.update(lr0=group["lr"])
    return optimizer


def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=1):
    """Rescale every group's learning rate from its stored ``lr0`` value.

    The scale factor is ``(1 + gamma * iter_num / max_iter) ** (-power)``.
    Each group's ``weight_decay``, ``momentum`` and ``nesterov`` entries are
    also overwritten with fixed values on every call.

    Args:
        optimizer: object exposing ``param_groups``, a sequence of dicts that
            each contain an ``"lr0"`` entry.
        iter_num: current iteration count.
        max_iter: total number of iterations.
        gamma: multiplier applied to the progress ratio.
        power: exponent of the decay.

    Returns:
        The same optimizer object, modified in place.
    """
    scale = (1 + gamma * iter_num / max_iter) ** (-power)
    fixed_settings = {"weight_decay": 1e-3, "momentum": 0.9, "nesterov": True}
    for group in optimizer.param_groups:
        group["lr"] = group["lr0"] * scale
        group.update(fixed_settings)
    return optimizer


class ImageList_idx(Dataset):
    """Dataset over an image list that also returns each sample's index.

    Items are ``(image, target, index)`` triples, so callers can map a batch
    element back to its position in the dataset.

    Args:
        image_list: passed through to ``make_dataset``.
        labels: passed through to ``make_dataset``.
        transform: optional callable applied to the loaded image.
        target_transform: optional callable applied to the target.
        mode: ``"RGB"`` selects ``rgb_loader``, ``"L"`` selects ``l_loader``.

    Raises:
        ValueError: if *mode* is neither ``"RGB"`` nor ``"L"``.
    """

    def __init__(
        self, image_list, labels=None, transform=None, target_transform=None, mode="RGB"
    ):
        # Reject unknown modes up front. Previously the loader attribute was
        # silently left unset and the failure only surfaced as an
        # AttributeError on the first __getitem__ call.
        if mode not in ("RGB", "L"):
            raise ValueError(
                "unsupported mode {!r}; expected 'RGB' or 'L'".format(mode)
            )

        self.imgs = make_dataset(image_list, labels)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = rgb_loader if mode == "RGB" else l_loader

    def __getitem__(self, index):
        """Load sample *index* and return ``(image, target, index)``."""
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index

    def __len__(self):
        """Return the number of samples."""
        return len(self.imgs)


def office_load_idx(args):
    """Build the Office-Home data loaders for the task named by ``args.dset``.

    ``args.dset`` has the form ``"<source>2<target>"`` where each side is one
    of ``a`` (Art), ``c`` (Clipart), ``p`` (Product) or ``r`` (Real_World).
    The image lists are read from ``./data/office-home/<Domain>.txt``.

    Args:
        args: namespace providing ``batch_size``, ``home``, ``dset`` and
            ``worker``.

    Returns:
        Dict of ``DataLoader`` objects keyed by ``"source_tr"``,
        ``"source_te"``, ``"target"`` and ``"test"``. All loaders shuffle and
        keep the last partial batch; ``"source_te"`` and ``"test"`` use two
        and three times the training batch size respectively.

    Raises:
        ValueError: if ``args.home`` is not set (no other dataset layout is
            implemented here) or ``args.dset`` does not name two known
            domains.
    """
    domain_names = {"a": "Art", "c": "Clipart", "p": "Product", "r": "Real_World"}
    train_bs = args.batch_size

    # Only the Office-Home branch exists; without it the datasets would never
    # be created, so fail with a clear message instead of a NameError later.
    if not args.home:
        raise ValueError("office_load_idx only supports Office-Home; set args.home")

    parts = args.dset.split("2")
    if len(parts) < 2 or parts[0] not in domain_names or parts[1] not in domain_names:
        raise ValueError(
            "args.dset must look like '<src>2<tgt>' with domains in {}; got {!r}".format(
                sorted(domain_names), args.dset
            )
        )
    source_domain = domain_names[parts[0]]
    target_domain = domain_names[parts[1]]

    source_list_path = "./data/office-home/{}.txt".format(source_domain)
    target_list_path = "./data/office-home/{}.txt".format(target_domain)

    # Context managers make sure the list files are closed after reading.
    with open(source_list_path) as handle:
        source_lines = handle.readlines()
    with open(target_list_path) as handle:
        target_lines = handle.readlines()

    prep_dict = {
        "source": image_train(),
        "target": image_target(),
        "test": image_test(),
    }
    # Both source datasets share the same list and the "source" transform.
    train_source = ImageList_idx(source_lines, transform=prep_dict["source"])
    test_source = ImageList_idx(source_lines, transform=prep_dict["source"])
    train_target = ImageList_idx(target_lines, transform=prep_dict["target"])
    test_target = ImageList_idx(target_lines, transform=prep_dict["test"])

    def make_loader(dataset, batch_size):
        # Every loader uses the same shuffle / worker / drop_last settings.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=args.worker,
            drop_last=False,
        )

    return {
        "source_tr": make_loader(train_source, train_bs),
        "source_te": make_loader(test_source, train_bs * 2),
        "target": make_loader(train_target, train_bs),
        "test": make_loader(test_target, train_bs * 3),
    }


def hyper_decay(x, beta=-5, alpha=1):
    """Return ``alpha * (1 + x) ** (-beta)``.

    Args:
        x: progress value the weight depends on.
        beta: negated exponent; the default of -5 raises the base to the
            fifth power.
        alpha: multiplicative scale applied to the result.

    Returns:
        The computed weight.
    """
    base = 1 + 1 * x
    exponent = -beta
    return base ** exponent * alpha


def train_target_decay(args):
    """Train a target network on Office-Home and checkpoint the best weights.

    Visible phases, in order:

    1. Load ``source_F.pt`` / ``source_C.pt`` from ``args.output_dir`` into a
       feature extractor and classifier and wrap them as ``source_model``
       (kept in eval mode).
    2. Load a model through ``clip_co.load`` and a Stable Diffusion img2img
       pipeline.
    3. Create a fresh ``netF`` / ``netC`` pair, an SGD optimizer and, when
       ``args.warmup`` is set, a warmup + cosine scheduler.
    4. Run one no-grad pass over the target loader to fill the feature bank,
       the pseudo labels and ``score_bank`` (softmax of ``source_model``).
    5. Run the diffusion pipeline over the "diffusion" loader and rebuild
       that loader from the generated images.
    6. Optimize until ``args.max_epoch * len(target loader)`` iterations have
       run, evaluating every ``max_iter // args.interval`` iterations and
       saving ``F_final.pt`` / ``C_final.pt`` whenever accuracy is at least
       as good as the best seen so far.

    Args:
        args: namespace of run settings. Besides the options registered by
            the parser in this file, the body also reads ``args.clip_load``,
            ``args.aaa``, ``args.bbb`` and ``args.ccc``.

    Returns:
        None. Results are written to ``args.out_file`` and to checkpoint
        files under ``args.output_dir``.
    """
    # NOTE(review): office_load_idx as defined in this file returns a single
    # dict with no "diffusion" entry, so this two-value unpack and the later
    # dset_loaders["diffusion"] lookups look inconsistent with it — confirm.
    dset_loaders, prompt = office_load_idx(args)
    ## set base network

    netF = network.ResNet_FE().cuda()
    netC = network.feat_classifier(
        type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck
    ).cuda()

    # oldC = network.feat_classifier(
    #     type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck
    # ).cuda()

    modelpath = args.output_dir + "/source_F.pt"
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir + "/source_C.pt"
    netC.load_state_dict(torch.load(modelpath))
    # oldC.load_state_dict(torch.load(modelpath))

    # The loaded source weights live on only inside source_model; netF/netC
    # are rebound to new instances further down.
    source_model = nn.Sequential(netF, netC).cuda()
    source_model.eval()

    clip_model, _, _ = clip_co.load(args.clip_load)
    clip_model = clip_model.cuda()
    text = clip_co.tokenize(prompt).cuda()

    # NOTE(review): config is loaded but not used anywhere else in this function.
    config = OmegaConf.load("./DiffusionModel/stable-diffusion-main/v1-inference.yaml")
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
         "/workspace/stable-diffusion-v1-5",
         safety_checker=None
     ).to("cuda")
    pipe.set_progress_bar_config(disable=True)
    pipe.progress_bar = lambda *args, **kwargs: None

    # acc1, _ = cal_acc_(dset_loaders["test"], netF, netC)  # 1
    # log_str = "Task: {}, Iter:{}/{}; Accuracy on target = {:.2f}%".format(
    #     args.dset, 0, 0, acc1 * 100
    # )
    # print(log_str)

    # Fresh networks for the target model; no state dict is loaded into them here.
    netF = network.ResNet_FE().cuda()
    netC = network.feat_classifier(
        type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck
    ).cuda()

    # With warmup enabled the base learning rate is raised tenfold (this
    # mutates args.lr in place).
    if args.warmup:
        args.lr *= 10

    optimizer = optim.SGD(
        [
            {"params": netF.feature_layers.parameters(), "lr": args.lr * 0.1},  # 1
            {"params": netF.bottle.parameters(), "lr": args.lr * 1},  # 10
            {"params": netF.bn.parameters(), "lr": args.lr * 1},  # 10
            {"params": netC.parameters(), "lr": args.lr * 1},  # 10
        ],
        momentum=0.9,
        weight_decay=5e-4,
        nesterov=True,
    )

    optimizer = op_copy(optimizer)

    if args.warmup:
        warmup_epochs = args.warmup_epochs
        scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_epoch - warmup_epochs, eta_min=1e-3)
        scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,
                                       after_scheduler=scheduler_cosine)
        scheduler.step()

    masking = Masking(
        block_size=args.mask_block_size,
        ratio=args.mask_ratio,
        blur=args.mask_blur,
        mean=args.norm_mean,
        std=args.norm_std)

    # Per-sample banks indexed by dataset position (the index returned as the
    # last element of every batch).
    acc_init = 0
    start = True
    loader = dset_loaders["target"]
    num_sample = len(loader.dataset)
    fea_bank = torch.randn(num_sample, 256)
    score_bank = torch.randn(num_sample, args.class_num).cuda()
    ind_keep = torch.ones(num_sample).cuda()
    pse_label = torch.zeros(num_sample).long().cuda()
    clip_all = torch.randn(num_sample, args.class_num).cuda()
    diffusion_bank = None
    image_diffusion = []

    # score_source = torch.randn(num_sample, args.class_num).cuda()

    # num_labeled = len(loader_labeled.dataset)
    # fea_labeled = torch.randn(num_labeled, 256)
    # score_labeled = torch.randn(num_labeled, args.class_num).cuda()
    # true_label = torch.zeros(num_labeled).long()
    print(num_sample)

    netF.eval()
    netC.eval()
    # oldC.eval()
    # Initial pass over the target data: fill fea_bank / pse_label from the
    # new networks and score_bank from the source model's softmax output.
    with torch.no_grad():
        # class_cm = torch.randn(num_sample, args.class_num)
        # fea_tsne = torch.randn(num_sample, 256)

        iter_test = iter(loader)
        for i in range(len(loader)):
            data = next(iter_test)
            inputs = data[0]
            indx = data[-1]
            # print(data[-1])
            # labels = data[1]
            inputs = inputs.cuda()
            output = netF.forward(inputs)  # a^t
            # NOTE(review): F is not imported in the visible part of this
            # file; presumably it arrives through `from utils import *`.
            output_norm = F.normalize(output)
            outputs_o = netC(output)
            outputs_old = nn.Softmax(-1)(outputs_o)
            outputs = nn.Softmax(-1)(outputs_o)
            fea_bank[indx] = output_norm.detach().clone().cpu()
            # score_bank[indx] = outputs.detach().clone()  # .cpu()
            values, indices = torch.max(outputs, dim=-1)
            pse_label[indx] = indices.detach().clone()

            score_bank[indx] = nn.Softmax(dim=1)(source_model(inputs)).detach().clone()
            # print((pse_label[indx] == data[1].cuda()).sum()/inputs.size(0))
            # print('___')

            # class_cm[indx] = outputs_cm.detach().clone().cpu()
            # fea_tsne[indx] = output.detach().clone().cpu()

        args.score_bank = score_bank



        acc1, _ = cal_acc_(dset_loaders["test"], netF, netC)  # 1
        # print("source")
        log_str = "Task: {}, Iter:{}/{}; Accuracy on target = {:.2f}%".format(
            args.dset, 0, 0, acc1 * 100
        )
        args.out_file.write(log_str + "\n")
        args.out_file.flush()
        print(log_str)
    text_features = prompt_tuning.prompt_main(args, dset_loaders["diffusion"].dataset.imgs,
                                                  score_bank.detach(),pipe)
    with torch.no_grad():

        to_tensor = transforms.ToTensor()
        iter_diffusion = iter(dset_loaders["diffusion"])
        for i in range(len(dset_loaders["diffusion"])):
            data = next(iter_diffusion)
            inputs = data[0]
            indx = data[-1]
            with torch.autocast("cuda"):
                # NOTE(review): `lower` and `arg` are not defined in the
                # visible code, and `mask_images` is read here before its
                # first assignment a few lines below — confirm intent.
                result = pipe(prompt="A clearer photo of {}".format(lower(arg.tt)), image=mask_images, strength=0.6, disable_progress_bar=True)
            # NOTE(review): `== None` is an equality test; `is None` is the
            # usual identity check and is safe once this holds a tensor.
            if diffusion_bank == None:
                diffusion_bank = to_tensor(result.images[0])
            else:
                # NOTE(review): concatenating along dim 0 extends the first
                # axis of the first image tensor instead of adding a new
                # per-image axis — verify this is the intended layout.
                diffusion_bank = torch.cat((diffusion_bank, to_tensor(result.images[0])),  dim=0)
            mask_images = masking(inputs,result.images[0])
            image_features = clip_model.encode_image(inputs,mask_images)
            logit_scale = clip_model.logit_scale.data
            logit_scale = logit_scale.exp().cpu()
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            logits = logit_scale * image_features @ text_features.t()
            softmax_logits = nn.Softmax(dim=1)(logits)
            vals, inds = torch.max(softmax_logits, dim=-1)
            # NOTE(review): diffusion_all is never assigned in this function;
            # it must come from module scope (clip_all above has the shape
            # this line seems to expect) — confirm.
            diffusion_all[indx] = softmax_logits.float()

        # Write every entry of diffusion_bank to disk as a JPEG and keep the
        # PIL images for the rebuilt "diffusion" dataset below.
        for j,image in enumerate(diffusion_bank):
            image = image.cpu().numpy()
            image = (image * 255).astype(np.uint8)
            image_pil = Image.fromarray(np.transpose(image, (1, 2, 0)))
            image_pil.save(f'./diffusion_image/saved_image_{j}.jpg')
            image_diffusion.append(image_pil)
    # NOTE(review): ImageList_idx forwards its first argument to make_dataset;
    # here it receives PIL images rather than list-file lines — verify.
    train_target = ImageList_idx(
            image_diffusion, transform=image_test()
        )
    dset_loaders["diffusion"] = DataLoader(
        train_target,
        batch_size=1,
        shuffle=True,
        num_workers=args.worker,
        drop_last=False,
    )


                # pse_test = pse_label.cpu()
            # ind_keep = ind_keep.cpu()
            # print(data[1][ind_keep[indx] == 1], pse_test[indx][ind_keep[indx] == 1])
        # np.savetxt('start_cm.txt', class_cm)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0
    epoch_n = 0
    decay_ema = decay_rate(args.ema, max_iter)
    flag_clip =0
    ff=1

    netF.train()
    netC.train()
    # oldC.train()


    while iter_num < max_iter:

        netF.train()
        netC.train()
        # The bare except doubles as "start a new epoch": it fires on the very
        # first pass (iter_target is not bound yet) and whenever the iterator
        # is exhausted. NOTE(review): it also hides any other error raised
        # while fetching a batch.
        try:
            inputs_test, _, tar_idx = next(iter_target)
            flag_clip = 0
        except:
            iter_target = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_target)
            epoch_n += 1
            flag_clip = 1

        # First half of training: refresh the prompt features at each epoch
        # start. NOTE(review): prompt_main is called with three arguments
        # here but four (including pipe) before the loop.
        if flag_clip == 1 and epoch_n < args.max_epoch//2:
            flag_clip = 0
            args.score_bank = score_bank
            text_features = prompt_tuning.prompt_main(args, dset_loaders["diffusion"].dataset.imgs,
                                                      diffusion_all.detach())
            iter_diffusion= iter(dset_loaders["diffusion"])
            predictor_diffusion(iter_diffusion,diffusion_bank, pipe, diffusion_all)

        # Second half: run finetuning exactly once (guarded by ff).
        elif epoch_n >= args.max_epoch//2 and ff == 1:
            ff = 0
            iter_diffusion= iter(dset_loaders["diffusion"])
            finetuning(iter_diffusion, text_features, diffusion_all, prompt_tuning, args)

        # Batches of size 1 are skipped; iter_num is not advanced for them.
        if inputs_test.size(0) == 1:
            continue

        # NOTE(review): alpha is computed but not used later in this function.
        alpha = (1 + 10 * iter_num / max_iter) ** (-args.beta) * 1

        inputs_test = inputs_test.cuda()

        iter_num += 1
        # lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
        """for _, (inputs_target, _,indx) in enumerate(iter_target):
            if inputs_target.size(0) == 1:
                continue"""
        inputs_target = inputs_test.cuda()

        inputs_masked = masking(inputs_target, diffusion_bank[tar_idx]).cuda() #inputs_test.cuda()

        features_mask = netF(inputs_masked)
        features_test = netF(inputs_target)
        # print(netF.mask.max())

        output = netC(features_test)
        softmax_out = nn.Softmax(dim=1)(output)
        output_re = softmax_out.unsqueeze(1)  # batch x 1 x num_class

        outputs_o = netC(features_mask)
        outputs_old = nn.Softmax(dim=1)(outputs_o)

        # Refresh the banks for the samples in this batch without tracking gradients.
        with torch.no_grad():
            output_f_norm = F.normalize(features_test)
            output_f_ = output_f_norm.cpu().detach().clone()
            output_f_norm_mask = F.normalize(features_mask)
            output_f_mask = output_f_norm_mask.cpu().detach().clone()

            values, indices = torch.max(softmax_out, dim=-1)
            pse_label[tar_idx] = indices.detach().clone()

            pred_bs = softmax_out

            fea_bank[tar_idx] = output_f_.detach().clone().cpu()
            # score_bank[tar_idx] = pred_bs.detach().clone()

            # Exponential moving average of the stored scores and the current prediction.
            ema = decay_ema(iter_num) if args.decay_ema else 0.6
            score_bank[tar_idx] = score_bank[tar_idx] * ema + softmax_out.detach() * (1 - ema)
            # score_bank[tar_idx] = score_bank[tar_idx] * args.ema + softmax_out.detach() * (1 - args.ema)

            # fea_labeled[labeled_idx] = output_norm.detach().clone().cpu()
            # score_labeled[labeled_idx] = softmax_labeled.detach().clone()

            # class_cm[tar_idx] = output.detach().clone().cpu()
            # fea_tsne[tar_idx] = features_test.detach().clone().cpu()


            ind_keep[tar_idx] = 1


        # KL term against the score bank, plus per-sample entropy minus the
        # entropy of the batch-mean prediction.
        loss = F.kl_div(softmax_out.log(), score_bank[tar_idx], reduction='batchmean')
        optimizer.zero_grad()
        entropy_loss = torch.mean(Entropy(softmax_out)) #torch.mean(Entropy(softmax_out))
        msoftmax = softmax_out.mean(dim=0)  #softmax_out.mean(dim=0)
        gentropy_loss = torch.sum(- msoftmax * torch.log(msoftmax + 1e-5))
        entropy_loss -= gentropy_loss
        loss += entropy_loss
        # NOTE(review): backward() is called here and again after the extra
        # terms are added, with no zero_grad() in between, so the gradients of
        # the terms above are accumulated twice before the step — confirm
        # this is intended.
        loss.backward(retain_graph=True)
        # # print(loss)
        if epoch_n >= args.max_epoch//2:
            loss += SimMaxLoss()(outputs_old) + SimMaxLoss()(softmax_out) + SimMinLoss()(outputs_old, softmax_out) * args.bbb
        else:
            one_hot = torch.argmax(diffusion_all[tar_idx], dim=1)
            # NOTE(review): cross_entropy is given softmax probabilities here,
            # not raw classifier outputs — verify.
            loss += F.cross_entropy(softmax_out, one_hot) * args.aaa  # torch.mean(torch.abs(softmax_out - clip_all[tar_idx]))
            # Off-diagonal mask: ones everywhere except the diagonal.
            mask = torch.ones((inputs_target.shape[0], inputs_target.shape[0]))
            diag_num = torch.diag(mask)
            mask_diag = torch.diag_embed(diag_num)
            mask = mask - mask_diag
            # `copy` is a local here and shadows the imported copy module
            # inside this function.
            copy = softmax_out.T   # .detach().clone()  #
            dot_neg = softmax_out @ copy  # batch x batch
            dot_neg = (dot_neg * mask.cuda()).sum(-1)  # batch
            neg_pred = torch.mean(dot_neg)
            loss += neg_pred * args.ccc


        # optimizer.zero_grad()
        # optimizer_oldC.zero_grad()
        loss.backward()

        # optimizer_oldC.step()

        # for param_oldC, param_netC in zip(oldC.parameters(), netC.parameters()):
        #     param_netC.grad = param_oldC.grad * 0.3 + param_netC.grad * 0.7

        optimizer.step()

        # NOTE(review): this steps the scheduler whenever iter_num divides
        # max_iter, which is not once per epoch — confirm the condition.
        if args.warmup and iter_num!=0 and max_iter % iter_num == 0:
            scheduler.step()

        # Periodic evaluation and best-checkpoint saving.
        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netC.eval()
            # oldC.eval()

            # print("target")
            acc1, _ = cal_acc_(dset_loaders["test"], netF, netC)  # 1
            # print("source")
            log_str = "Task: {}, Iter:{}/{}; Accuracy on target = {:.2f}%".format(
                args.dset, iter_num, max_iter, acc1 * 100
            )
            args.out_file.write(log_str + "\n")
            args.out_file.flush()
            # NOTE(review): len(ind_keep == 1) is the length of the comparison
            # result (num_sample), not the number of entries equal to 1.
            print(log_str,len(ind_keep == 1))

            if acc1 >= acc_init:
                acc_init = acc1
                # np.savetxt('best_fea_CVH.txt', fea_tsne)
                # np.savetxt('best_cm.txt', class_cm)
                best_netF = netF.state_dict()
                best_netC = netC.state_dict()

                torch.save(best_netF, osp.join(args.output_dir, "F_final.pt"))
                torch.save(best_netC, osp.join(args.output_dir, "C_final.pt"))


def parse_bool_arg(value):
    """Convert a command-line token into a bool for use as argparse ``type=``.

    ``type=bool`` treats every non-empty string (including ``"False"``) as
    True; this parser recognises the usual spellings instead.

    Args:
        value: a bool, or a string such as ``"true"``, ``"False"``, ``"1"``,
            ``"no"``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "on", "1"):
        return True
    if token in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value)
    )


if __name__ == "__main__":
    # NOTE(review): train_target_decay also reads args.clip_load, args.aaa,
    # args.bbb and args.ccc, none of which are registered below — confirm
    # where they are meant to come from.
    parser = argparse.ArgumentParser(
        description="Domain Adaptation on office-home dataset"
    )
    parser.add_argument(
        "--gpu_id", type=str, nargs="?", default="1", help="device id to run"
    )
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    parser.add_argument('--mask_block_size', default=32, type=int)
    parser.add_argument('--mask_ratio', default=0.5, type=float)
    # type=bool would turn "--mask_blur False" into True.
    parser.add_argument('--mask_blur', default=False, type=parse_bool_arg)
    parser.add_argument("--s", type=int, default=0, help="source")
    parser.add_argument("--t", type=int, default=1, help="target")
    parser.add_argument("--max_epoch", type=int, default=100, help="maximum epoch")
    parser.add_argument("--batch_size", type=int, default=64, help="batch_size")
    parser.add_argument("--interval", type=int, default=15)
    parser.add_argument("--worker", type=int, default=8, help="number of workers")
    parser.add_argument("--dset", type=str, default="a2r")
    parser.add_argument("--choice", type=str, default="shot")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--seed", type=int, default=2021, help="random seed")
    parser.add_argument("--class_num", type=int, default=65)
    parser.add_argument("--par", type=float, default=0.1)
    parser.add_argument("--bottleneck", type=int, default=256)
    parser.add_argument("--layer", type=str, default="wn", choices=["linear", "wn"])
    parser.add_argument("--classifier", type=str, default="bn", choices=["ori", "bn"])
    parser.add_argument("--smooth", type=float, default=0.1)
    parser.add_argument("--output", type=str, default="weight")  # trainingC_2
    parser.add_argument("--file", type=str, default="noDIV")
    parser.add_argument("--home", action="store_true")
    parser.add_argument("--multi", action="store_true", default=False)
    parser.add_argument("--NOneg", default=False, action="store_true")
    parser.add_argument("--affi_neg", default=False, action="store_true")
    parser.add_argument("--ori", default=False, action="store_true")
    parser.add_argument("--no2hop", default=False, action="store_true")
    parser.add_argument("--onlyNN", default=False, action="store_true")
    parser.add_argument("--self", default=False, action="store_true")
    parser.add_argument("--cc", default=False, action="store_true")
    parser.add_argument("--alpha_decay", default=False, action="store_true")
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=0.75)
    parser.add_argument("--topKNEG", default=False, action="store_true")
    parser.add_argument("--conf", default=False, action="store_true")
    parser.add_argument("--lr_decay", default=False, action="store_true")
    parser.add_argument("--r_batch", default=False, action="store_true")
    parser.add_argument("--noGRAD", default=False, action="store_true")
    parser.add_argument("--sharp", default=False, action="store_true")
    parser.add_argument("--sharp_neg", default=False, action="store_true")
    parser.add_argument("--filter_beta", type=float, default=0.05)
    # Passing --decay_ema / --align switches the respective option OFF.
    parser.add_argument("--decay_ema", action="store_false", default=True)
    parser.add_argument("--warmup_epochs", type=int, default=3)
    parser.add_argument("--warmup", action="store_true", default=False)
    parser.add_argument('--ema', type=float, default=0.95)
    parser.add_argument("--align", action="store_false", default=True)
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # Seed every random number generator used by the run.
    SEED = args.seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    torch.backends.cudnn.deterministic = True

    current_folder = "./"
    # NOTE(review): the directory name uses the literal 2024 rather than
    # args.seed; kept as-is because the source checkpoints are read from this
    # directory — confirm this is intended.
    args.output_dir = osp.join(
        current_folder, args.output, "seed" + str(2024), args.dset
    )
    # os.makedirs avoids spawning a shell and works when the path contains
    # spaces or shell metacharacters.
    os.makedirs(args.output_dir, exist_ok=True)

    # The context manager closes the log file even if training raises.
    with open(osp.join(args.output_dir, args.file + ".txt"), "w") as out_file:
        args.out_file = out_file
        args.out_file.write(print_args(args) + "\n")
        args.out_file.flush()
        train_target_decay(args)
