import os
import sys
import math
import time
import copy
import faiss
import random

import torch as th
import numpy as np
from einops import rearrange
import torchvision.utils as tvu
import torch.utils.data as data
import torch.distributed as dist
import torchvision.transforms.functional as F
from torchvision import transforms
from sklearn.metrics import roc_auc_score
from tqdm.auto import tqdm

from dataset import get_dataset
from runner.runner import Runner
from model.ood.resnet18 import ResNet18_32x32
from model.ood.resnet50 import ResNet50
from detect.dataset.imglist_dataset import ImglistDataset


class KNN(object):
    """Exact (brute-force L2) nearest-neighbour index over L2-normalised embeddings.

    Embeddings are produced by :meth:`encoder` and stored in a
    ``faiss.IndexFlatL2``. The raw tensors passed to :meth:`add` are kept in
    ``self.y`` so that :meth:`search` can return the neighbours themselves,
    not only their ids.

    Each method is wrapped in ``torch.no_grad()``. Decorating the class itself
    only wraps construction and leaves the methods running with autograd
    enabled.
    """

    def __init__(self, dim, net=None):
        # dim: width of the embedding produced by `encoder`.
        # net: optional feature extractor called as net(x, return_feature=True);
        #      when None, the embedding is the flattened mean over dim 1.
        self.dim = dim
        self.index = faiss.IndexFlatL2(dim)
        self.net = net
        self.y = None  # concatenation (dim 0) of every tensor passed to `add`

    @th.no_grad()
    def encoder(self, y: th.Tensor) -> np.ndarray:
        """Embed a batch as row-wise L2-normalised float32 vectors.

        Args:
            y: batch tensor whose dim 0 indexes samples. With a ``net``, it is
               mapped by ``y * 0.5 + 0.5`` and clamped to [0, 1] (the inverse of
               the ``x * 2 - 1`` scaling used elsewhere) before the forward pass.

        Returns:
            C-contiguous float32 array of shape ``(len(y), self.dim)``.

        Raises:
            ValueError: if the embedding width differs from ``self.dim``.
        """
        if self.net is not None:
            y = th.clamp(y * 0.5 + 0.5, 0, 1)
            _, feats = self.net(y, return_feature=True)
            feats = feats.cpu().numpy()
        else:
            feats = y.mean(dim=1).cpu().numpy().reshape(len(y), -1)

        # faiss only accepts C-contiguous float32 input
        feats = np.ascontiguousarray(feats, dtype=np.float32)
        feats = feats / (np.linalg.norm(feats, axis=-1, keepdims=True) + 1e-10)
        if feats.shape[1] != self.dim:
            raise ValueError(f'embedding dim {feats.shape[1]} does not match index dim {self.dim}')
        return feats

    @th.no_grad()
    def add(self, y: th.Tensor):
        """Embed ``y``, add it to the faiss index and append it to ``self.y``."""
        self.index.add(self.encoder(y))
        self.y = y if self.y is None else th.cat([self.y, y], dim=0)
        print('the shape of y in KNN', y.shape)

    @th.no_grad()
    def search(self, y: th.Tensor, k=1, return_y=False):
        """Find the ``k`` nearest stored entries for each row of ``y``.

        Returns:
            ``(distances, ind)`` as returned by ``IndexFlatL2.search``, each of
            shape ``(len(y), k)``. With ``return_y=True`` it also returns the
            stored tensors of those neighbours, shaped ``(len(y), k, ...)``.

        Raises:
            ValueError: with ``return_y=True`` when faiss could not fill all
                ``k`` slots (it pads with id -1, which would otherwise silently
                select the last stored item).
        """
        distances, ind = self.index.search(self.encoder(y), k)
        if not return_y:
            return distances, ind
        if (ind < 0).any():
            raise ValueError(f'index holds fewer than k={k} entries')
        # ind is (n, k); flattening is row-major, so entry i*k + j is neighbour j of query i
        neighbours = self.y[th.from_numpy(ind.reshape(-1))]
        neighbours = rearrange(neighbours, '(b1 b2) ... -> b1 b2 ...', b2=k)
        return distances, ind, neighbours


class OodDetection(Runner):
    """Diffusion-reconstruction OOD detection experiments.

    Adds a 10-class ``ResNet18_32x32`` classifier (``self.discriminator``),
    loaded from ``temp/model/ood_cifar10_res18.ckpt``, to the diffusion
    ``model``/``schedule``. The classifier labels inputs for the
    class-conditional diffusion model and produces logits/features for the
    reconstructions.

    Uses ``self.device``, ``self.rank``, ``self.world_size``, ``self.args``,
    ``self.model``, ``self.schedule`` and ``self.before_sample()``, which are
    presumably provided by ``Runner`` (not visible here).
    """

    # Channel mean/std used to normalise classifier inputs, keyed by the
    # in-distribution dataset name (``args.model_name``).
    _NORM_DICT = {
        'cifar10': [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]],
        'cifar100': [[0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]],
        'imagenet': [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]],
    }

    def __init__(self, args, config, schedule, model):
        super(OodDetection, self).__init__(args, config, schedule, model)
        # self-trained CIFAR-10 classifier
        self.discriminator = ResNet18_32x32(num_classes=10).to(self.device)
        state_dict = th.load('temp/model/ood_cifar10_res18.ckpt', map_location=self.device)
        try:
            self.discriminator.load_state_dict(state_dict, strict=True)
        except RuntimeError:
            # Checkpoints saved from a (Distributed)DataParallel wrapper prefix
            # every key with 'module.'; strip only that prefix and retry.
            from collections import OrderedDict
            prefix = 'module.'
            new_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items())
            self.discriminator.load_state_dict(new_state_dict, strict=True)
        # the classifier is only used for inference in this runner
        self.discriminator.eval()

    def _gather(self, obj):
        """Concatenate ``obj`` from every rank along dim 0 and return it on the CPU.

        With fewer than two processes, ``obj`` is returned unchanged. Under
        DDP this calls ``dist.all_gather``, so every rank must call it with a
        tensor of identical shape.
        """
        if self.world_size < 2:
            return obj
        obj = obj.to(self.device)
        parts = [th.zeros_like(obj) for _ in range(self.world_size)]
        dist.all_gather(parts, obj)
        return th.cat(parts, dim=0).cpu()

    def _ood_loader(self, id_name, ood_name, batch_size):
        """Build a shuffled loader over ``test_{ood_name}.txt`` of the ``id_name`` benchmark (32px)."""
        dataset = ImglistDataset(id_name, 'test', 32,
                                 f'./data/benchmark_imglist/{id_name}/test_{ood_name}.txt',
                                 './data/images_classic/')
        if self.world_size >= 2:
            sampler = data.distributed.DistributedSampler(dataset)
            return data.DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=4)
        return data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    @th.no_grad()
    def noise_encoder(self):
        """Invert every training image to diffusion noise and save both to disk.

        Each batch is scaled by ``x * 2 - 1`` and passed through
        ``schedule.multi_iteration`` from step -1 to step 979. Results from all
        ranks are gathered, and rank 0 writes them to
        ``temp/noise/{model_name}_img_1.npy`` and ``..._noise_1.npy``.
        """
        model = self.model
        schedule = self.schedule
        device = self.device
        _, _, train_loader = self.before_sample()

        images, noises = [], []
        # report progress from the last rank only (same rule as interp_detect)
        show_bar = self.rank == self.world_size - 1
        for img, _ in tqdm(train_loader, disable=not show_bar):
            img = img.to(device) * 2 - 1
            noise_repr = schedule.multi_iteration(img, -1, 980 - 1, model,
                                                  last=True, fresh=True, continuous=False)
            images.append(img.cpu())
            noises.append(noise_repr.cpu())

        images = self._gather(th.cat(images, dim=0)).numpy()
        noises = self._gather(th.cat(noises, dim=0)).numpy()

        if self.rank == 0:
            print(images.shape, noises.shape)
            np.save(f'temp/noise/{self.args.model_name}_img_1.npy', images)
            np.save(f'temp/noise/{self.args.model_name}_noise_1.npy', noises)

    @th.no_grad()
    def enhancement(self):
        """Save the label-3 ('cat' in CIFAR-10 order) images of the first training batch.

        Writes ``temp/cat_np.npy``, which ``test`` loads. An empty list is
        saved if the first batch contains no such image or the loader is empty.
        """
        _, _, train_loader = self.before_sample()

        cat_np = []
        first_batch = next(iter(train_loader), None)
        if first_batch is not None:
            img, y = first_batch
            cat_np = [img[i].numpy() for i in range(len(y)) if y[i] == 3]

        np.save('temp/cat_np.npy', cat_np)

    @th.no_grad()
    def interp_detect(self):
        """Reconstruct OOD benchmark images with the diffusion model and record classifier outputs.

        For each split in ``ood_list``, the first ``iter_size`` batches (per
        rank) are processed as follows. Each batch is labelled by the
        discriminator. It is then noised to every timestep in ``timesteps``
        with ``repeat_size`` independent Gaussian noises and denoised with the
        diffusion model conditioned on the predicted label. Finally, each
        reconstruction is fed back through the discriminator.

        Reconstructions, features and logits are gathered across ranks. Rank 0
        saves them to ``temp/sample_ood/{id}_{ood}_knn{k}_{repeat}.npz``, where
        the leading axis of every array indexes the timestep. The process exits
        when done.
        """
        from itertools import islice

        batch_size = 250
        iter_size = 4  # batches per OOD split, shared across ranks under DDP
        if self.world_size >= 2:
            # 4 // world_size is 0 for more than 4 ranks; keep at least one batch each
            iter_size = max(1, iter_size // self.world_size)
        repeat_size = self.args.repeat_size
        knn_num = self.args.debug_value  # only used to tag the output file name
        id_name = self.args.model_name
        timesteps = [0, 240]  # diffusion steps the inputs are noised to

        model = self.model
        schedule = self.schedule
        device = self.device
        show_bar = self.rank == self.world_size - 1

        # load model
        self.before_sample()

        mean, std = self._NORM_DICT[id_name]
        norm_fn = transforms.Normalize(mean=mean, std=std)

        ood_list = ['svhn', 'texture', 'places365', 'cifar100', 'cifar10', 'tin']
        # ImageNet benchmark: ['imagenet', 'inaturalist', 'openimage_o', 'imagenet_o', 'species']

        for ood_name in ood_list:
            ood_loader = self._ood_loader(id_name, ood_name, batch_size)

            # one list per timestep, each collecting per-batch tensors
            imgr_np = [[] for _ in timesteps]
            fea_np = [[] for _ in timesteps]
            out_np = [[] for _ in timesteps]
            batches = islice(ood_loader, iter_size)
            for output in tqdm(batches, total=min(iter_size, len(ood_loader)),
                               disable=not show_bar, desc=f'process {ood_name} data'):
                img = output['data'].to(device)
                ood_img = img * 2 - 1

                noise_list = [th.randn_like(ood_img) for _ in range(repeat_size)]
                # the predicted label of the clean input conditions the diffusion model
                score = th.softmax(self.discriminator(norm_fn(img), return_feature=False), dim=1)
                _, y_pred = th.max(score, dim=1)
                # size the timestep tensor by the actual batch (the last one may be short)
                n = len(ood_img)

                tq = tqdm(total=len(timesteps) * repeat_size, leave=False, desc='subprocess',
                          disable=not show_bar)
                for j, t in enumerate(timesteps):
                    t_batch = th.full((n,), t, device=device, dtype=th.long)
                    imgr_rp, fea_rp, out_rp = [], [], []
                    for k in range(repeat_size):
                        tq.update()
                        img_n, _, _ = schedule.diffusion(ood_img, t_batch, noise=noise_list[k])
                        img_r = schedule.multi_iteration(img_n, t - 1, -1, model, y=y_pred,
                                                         last=True, fresh=True, continuous=False)
                        # back to [0, 1], then the classifier's input normalisation
                        img_r = norm_fn(th.clamp(img_r * 0.5 + 0.5, 0, 1))
                        logit, feature = self.discriminator(img_r, return_feature=True)

                        imgr_rp.append(img_r.cpu())
                        out_rp.append(logit.cpu())
                        fea_rp.append(feature.cpu())

                    # (batch, repeat, ...) -> (batch * repeat, ...): repeats of one sample stay adjacent
                    imgr_np[j].append(rearrange(th.stack(imgr_rp, dim=1), 'b r ... -> (b r) ...'))
                    fea_np[j].append(rearrange(th.stack(fea_rp, dim=1), 'b r ... -> (b r) ...'))
                    out_np[j].append(rearrange(th.stack(out_rp, dim=1), 'b r ... -> (b r) ...'))
                tq.close()

            if not imgr_np[0]:
                print(f'no data loaded for {ood_name}, skipping')
                continue

            # concatenate this rank's batches, then collect every rank's samples
            imgr_np = [self._gather(th.cat(x, dim=0)).numpy() for x in imgr_np]
            fea_np = [self._gather(th.cat(x, dim=0)).numpy() for x in fea_np]
            out_np = [self._gather(th.cat(x, dim=0)).numpy() for x in out_np]

            # every rank holds the same gathered data; write the file once
            if self.rank == 0:
                np.savez(f'temp/sample_ood/{id_name}_{ood_name}_knn{knn_num}_{repeat_size}.npz',
                         imgr=imgr_np, fea=fea_np, out=out_np)

        sys.exit()

    @th.no_grad()
    def test(self):
        """Swap-noise reconstruction experiment on two saved cat images.

        Loads images 9 and 0 from ``temp/cat_np.npy`` (written by
        ``enhancement``) and inverts each to noise at step 979 with class
        label 3. Then, for t = 100, 200, ..., 900, each image is noised to
        step t with the *other* image's inverted noise and denoised back.
        Results are saved as ``11-{i}.png`` (image 9) and ``22-{i}.png``
        (image 0) in ``args.image_path``. The process exits when done.
        """
        cats = np.load('temp/cat_np.npy')
        image1 = th.from_numpy(cats[9]).float().view(1, 3, 32, 32).to(self.device) * 2 - 1
        image2 = th.from_numpy(cats[0]).float().view(1, 3, 32, 32).to(self.device) * 2 - 1

        model = self.model
        schedule = self.schedule
        device = self.device
        label = th.full((1,), 3, dtype=th.long, device=device)  # CIFAR-10 class 3 ('cat')

        self.before_sample()

        def invert(x):
            # image -> noise at step 979
            return schedule.multi_iteration(x, -1, 980 - 1, model, y=label,
                                            last=True, fresh=True, continuous=False)

        def regenerate(x, t, noise):
            # noise x to step t with the given noise, denoise to step 0, map to [0, 1]
            x_t, _, _ = schedule.diffusion(x, th.full((1,), t, device=device, dtype=th.long),
                                           noise=noise)
            x_0 = schedule.multi_iteration(x_t, t - 1, -1, model, y=label,
                                           last=True, fresh=True, continuous=False)
            return th.clamp(x_0 * 0.5 + 0.5, 0, 1)

        noise1 = invert(image1)
        noise2 = invert(image2)

        for i in range(1, 10):
            img_r2 = regenerate(image2, 100 * i, noise1)
            img_r1 = regenerate(image1, 100 * i, noise2)
            tvu.save_image(img_r1[0], os.path.join(self.args.image_path, f"11-{i}.png"))
            tvu.save_image(img_r2[0], os.path.join(self.args.image_path, f"22-{i}.png"))
        sys.exit()
