import os
# Set before torch is imported. PyTorch's deterministic mode requires a fixed
# cuBLAS workspace size (":4096:8" or ":16:8") for cuBLAS ops on CUDA >= 10.2.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import glob
import sys
import math
import time
import datetime
import shutil

import numpy as np

import torch
# Make torch raise an error for any op that lacks a deterministic implementation.
torch.use_deterministic_algorithms(True)

import torch.nn.utils as utils
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
from torchvision.utils import save_image
# cuDNN: use deterministic kernels and disable autotuning, since benchmark mode
# may pick different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# NOTE: repeats the call a few lines above. Setting the flag again has no further effect.
torch.use_deterministic_algorithms(True) 
from decimal import Decimal
from tqdm import tqdm
import cv2

import utility
from basicsr.archs.quantize import QConv2d
import kornia as K


class Trainer():
    """Evaluation driver for a network built from ``QConv2d`` quantized convolutions.

    The methods here only evaluate. ``test`` runs ``model.net_g`` at the
    configured bit-widths (``args.quantize_w`` / ``args.quantize_a``), and
    ``test_teacher`` temporarily forces every ``QConv2d`` to 32 bits.
    ``__init__`` also builds Adam optimizers and StepLR schedulers for the
    quantizer parameters, but nothing in this class steps them.
    """

    def __init__(self, args, loader, my_model, ckp, logger, opt, tb_logger):
        """Store collaborators, build quantizer optimizers and put QConv2d layers in init mode.

        Args:
            args: Parsed options. Attributes read here include ``scale``,
                ``layerwise``, ``videowise``, the ``lr_*`` rates, ``betas``,
                ``epsilon``, ``step``, ``gamma`` and ``test_only``.
            loader: Iterable of evaluation loaders used by ``test`` and ``test_teacher``.
            my_model: Model wrapper exposing ``net_g`` (``test`` also calls its ``validation``).
            ckp: Checkpoint/log helper. ``ckp.dir`` receives a copy of ``./trainer.py``.
            logger: Logger with an ``info`` method.
            opt: Options dict. ``test`` reads ``opt['val']['save_img']`` from it.
            tb_logger: Logger passed to ``model.validation``; ``test`` closes it.
        """
        self.args = args
        self.scale = args.scale
        self.ckp = ckp
        self.loader_test = loader

        self.model = my_model
        self.epoch = 0
        self.opt = opt
        self.logger = logger
        self.tb_logger = tb_logger

        # Keep a copy of this script with the run's outputs. The source path is
        # relative, so this expects the process to start in the script's directory.
        shutil.copyfile('./trainer.py', os.path.join(self.ckp.dir, 'trainer.py'))

        named_params = list(self.model.net_g.named_parameters())
        # NOTE(review): these are plain substring tests. Any parameter whose
        # dotted name merely contains '_a' / '_w' (e.g. a layer named
        # 'conv_after_body') is also selected; confirm against QConv2d naming.
        quant_params_a = [v for k, v in named_params if '_a' in k]
        quant_params_w = [v for k, v in named_params if '_w' in k]

        if args.layerwise or args.videowise:
            quant_params_measure = []
            if args.layerwise:
                quant_params_measure_layer = [v for k, v in named_params if 'measure_layer' in k]
                quant_params_measure.append({'params': quant_params_measure_layer, 'lr': args.lr_measure_layer})
            if args.videowise:
                quant_params_measure_image = [
                    v for k, v in named_params if 'measure' in k and 'measure_layer' not in k
                ]
                quant_params_measure.append({'params': quant_params_measure_image, 'lr': args.lr_measure_video})

            self.optimizer_measure = torch.optim.Adam(quant_params_measure, betas=args.betas, eps=args.epsilon)
            self.scheduler_measure = lrs.StepLR(self.optimizer_measure, step_size=args.step, gamma=args.gamma)

        self.optimizer_a = torch.optim.Adam(quant_params_a, lr=args.lr_a, betas=args.betas, eps=args.epsilon)
        self.optimizer_w = torch.optim.Adam(quant_params_w, lr=args.lr_w, betas=args.betas, eps=args.epsilon)
        self.scheduler_a = lrs.StepLR(self.optimizer_a, step_size=args.step, gamma=args.gamma)
        self.scheduler_w = lrs.StepLR(self.optimizer_w, step_size=args.step, gamma=args.gamma)

        self.skt_losses = utility.AverageMeter()
        self.pix_losses = utility.AverageMeter()
        self.bit_losses = utility.AverageMeter()

        # Number of QConv2d layers that count toward the average bit-width.
        # Layers flagged to_8bit (per the original note, the first/last layers)
        # are excluded from the bit count.
        self.num_quant_modules = sum(
            1 for m in self.model.net_g.modules()
            if isinstance(m, QConv2d) and not m.to_8bit
        )

        if not args.test_only:
            # Start every QConv2d at full precision with its init flag set.
            # set_bit() clears the flag again.
            for m in self.model.net_g.modules():
                if isinstance(m, QConv2d):
                    m.w_bit = 32.0
                    m.a_bit = 32.0
                    m.init = True
            if args.videowise:
                # Assumes net_g is a wrapper exposing .module (e.g. DataParallel); TODO confirm.
                self.model.net_g.module.init = True

    def set_bit(self, teacher=False):
        """Assign bit-widths to every QConv2d in ``net_g`` and clear its ``init`` flag.

        Args:
            teacher: If True, use 32-bit weights and activations everywhere.
                Otherwise, ``non_adaptive`` layers flagged ``to_8bit`` get 8 bits,
                and all other layers get ``args.quantize_w`` / ``args.quantize_a``.
        """
        for m in self.model.net_g.modules():
            if not isinstance(m, QConv2d):
                continue
            if teacher:
                w_bit, a_bit = 32.0, 32.0
            elif m.non_adaptive and m.to_8bit:
                w_bit, a_bit = 8.0, 8.0
            else:
                w_bit, a_bit = self.args.quantize_w, self.args.quantize_a
            m.w_bit = w_bit
            m.a_bit = a_bit
            m.init = False

        if self.args.videowise:
            self.model.net_g.module.init = False

    @staticmethod
    def _chunk_slices(total, size):
        """Return slices that split ``range(total)`` into consecutive groups of at most ``size``."""
        return [slice(start, start + size) for start in range(0, total, size)]

    def patch_inference(self, model, lr, idx_scale):
        """Super-resolve ``lr`` tile by tile and stitch the results back together.

        ``utility.crop`` / ``utility.crop_parallel`` cut ``lr`` into
        ``args.test_patch_size`` tiles at stride ``args.test_step_size``. When
        ``args.n_parallel != 1``, tiles go to the model ``n_parallel`` at a time.

        Args:
            model: Unused; ``self.model`` is always called. Kept for interface compatibility.
            lr: Low-resolution input.
            idx_scale: Scale index forwarded to the model.

        Returns:
            ``(sr, feat, bit)``:
                - ``sr``: the stitched output.
                - ``feat``: the ``feat`` from the last forward call.
                - ``bit``: the mean over forward calls of ``bit.mean() / num_quant_modules``.
        """
        size, step = self.args.test_patch_size, self.args.test_step_size
        n_calls = 0
        tot_bit_image = 0
        if self.args.n_parallel != 1:
            lr_list, num_h, num_w, h, w = utility.crop_parallel(lr, size, step)
            sr_list = torch.Tensor().cuda()
            # One slice per group of n_parallel tiles. The old
            # ``range(len // n + 1)`` loop ran one extra, empty group whenever the
            # tile count was a multiple of n_parallel. That group sent an empty
            # batch to the model and was counted in the bit average.
            for group in self._chunk_slices(len(lr_list), self.args.n_parallel):
                torch.cuda.empty_cache()
                with torch.no_grad():
                    sr_sub, feat, bit = self.model(lr_list[group], idx_scale)
                    sr_sub = utility.quantize(sr_sub, self.args.rgb_range)
                sr_list = torch.cat([sr_list, sr_sub])
                tot_bit_image += bit.mean() / self.num_quant_modules
                n_calls += 1
        else:
            lr_list, num_h, num_w, h, w = utility.crop(lr, size, step)
            sr_list = []
            for lr_sub_img in lr_list:
                torch.cuda.empty_cache()
                with torch.no_grad():
                    sr_sub, feat, bit = self.model(lr_sub_img, idx_scale)
                    sr_sub = utility.quantize(sr_sub, self.args.rgb_range)
                sr_list.append(sr_sub)
                tot_bit_image += bit.mean() / self.num_quant_modules
                n_calls += 1

        sr = utility.combine(sr_list, num_h, num_w, h, w, size, step, self.scale[0])
        return sr, feat, tot_bit_image / n_calls

    def _infer(self, lr, idx_scale):
        """Run one forward pass (tiled when ``args.test_patch``) and return ``(sr, feat, img_bit)``.

        ``img_bit`` is ``bit.mean() / num_quant_modules``, the same per-call
        normalisation ``patch_inference`` uses.
        """
        if self.args.test_patch:
            return self.patch_inference(self.model, lr, idx_scale)
        with torch.no_grad():
            sr, feat, bit = self.model(lr, idx_scale)
        return sr, feat, bit.mean() / self.num_quant_modules

    def _load_own_image(self):
        """Load ``args.test_own`` as a 1xCxHxW float CUDA tensor, channels flipped BGR->RGB."""
        test_img = cv2.imread(self.args.test_own)
        if test_img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError('could not read image: {}'.format(self.args.test_own))
        lr = torch.tensor(test_img).permute(2, 0, 1).float().cuda()
        lr = torch.flip(lr, (0,))  # cv2 loads BGR; reverse the channel axis
        return lr.unsqueeze(0)

    def _own_image_stem(self):
        """Return the file name of ``args.test_own`` without directory or extension."""
        return self.args.test_own.split('/')[-1].split('.')[0]

    def _log_model_size(self):
        """Log the parameter count and the bit-width-weighted model size of ``net_g``.

        A QConv2d ``weight`` contributes ``numel * w_bit / 32``. Every other
        parameter contributes its full element count.
        """
        net = self.model.net_g
        n_params = 0
        n_params_q = 0
        for name, param in net.named_parameters():
            numel = param.numel()
            n_params += numel
            if 'weight' in name:
                # Walk the dotted path to the module that owns this parameter.
                owner = net
                for attr in name.split('.')[:-1]:
                    owner = getattr(owner, attr)
                if isinstance(owner, QConv2d):
                    n_params_q += numel * owner.w_bit / 32.0
                    continue
            n_params_q += numel

        for label, value in (('Parameters', n_params), ('Model Size', n_params_q)):
            msg = '{}: {:.3f}K'.format(label, value / (10**3))
            self.ckp.write_log(msg)
            self.logger.info(msg)

    def test(self):
        """Evaluate the quantized network and log model size and results.

        - With ``args.test_own`` set, only that image is super-resolved
          (optionally saved) and its average bit-width is logged.
        - Otherwise every loader in ``self.loader_test`` is evaluated via
          ``model.validation`` at the configured bit-widths, and ``tb_logger``
          is closed afterwards.

        Gradient tracking is disabled for the duration and re-enabled on exit,
        including when an exception is raised.
        """
        torch.set_grad_enabled(False)
        try:
            self.ckp.write_log('\nEvaluation:')
            self.ckp.add_log(
                torch.zeros(1, len(self.loader_test), len(self.scale))
            )
            self.model.net_g.eval()

            if self.epoch or self.args.test_only:
                self._log_model_size()

            if self.args.save_results:
                self.ckp.begin_background()

            if self.args.test_own is not None:
                lr = self._load_own_image()
                filename = self._own_image_stem()
                for idx_scale, scale in enumerate(self.scale):
                    sr, _feat, img_bit = self._infer(lr, idx_scale)
                    sr = utility.quantize(sr, self.args.rgb_range)
                    if self.args.save_results:
                        save_name = '{}_x{}_{:.2f}bit'.format(filename, scale, img_bit)
                        self.ckp.save_results('test_own', save_name, [sr])
                    self.ckp.write_log('[{} x{}] Average Bit: {:.2f} '.format(filename, scale, img_bit))
            else:
                for val_loader in self.loader_test:
                    self.set_bit(teacher=False)
                    with torch.no_grad():
                        self.model.validation(
                            val_loader, self.epoch, self.tb_logger,
                            self.opt['val']['save_img'], self.num_quant_modules,
                        )
                if self.tb_logger:
                    self.tb_logger.close()

            if self.args.save_results:
                self.ckp.end_background()
        finally:
            torch.set_grad_enabled(True)

    def test_teacher(self):
        """Evaluate the network with every QConv2d forced to 32 bits (the "teacher").

        Bit-widths are switched to 32 around each forward pass and back to the
        quantized settings afterwards. Gradient tracking is disabled for the
        duration and re-enabled on exit.
        """
        torch.set_grad_enabled(False)
        try:
            # NOTE(review): this uses self.model.eval()/named_parameters(), while
            # __init__ and test() go through self.model.net_g. Confirm that
            # self.model itself exposes these.
            self.model.eval()
            self.ckp.write_log('Teacher Evaluation')

            # Count only the network weights; quantizer ('_a', '_w') and
            # 'measure' parameters are skipped.
            n_params = sum(
                v.numel() for k, v in self.model.named_parameters()
                if '_a' not in k and '_w' not in k and 'measure' not in k
            )
            self.ckp.write_log('Parameters: {:.3f}K'.format(n_params / (10**3)))

            if self.args.save_results:
                self.ckp.begin_background()

            if self.args.test_own is not None:
                self._teacher_test_own()
            else:
                self._teacher_test_sets()

            if self.args.save_results:
                self.ckp.end_background()
        finally:
            torch.set_grad_enabled(True)

    def _teacher_test_own(self):
        """Teacher pass over the single ``args.test_own`` image, for each scale."""
        lr = self._load_own_image()
        filename = self._own_image_stem()
        for idx_scale, scale in enumerate(self.scale):
            self.set_bit(teacher=True)
            sr, _feat, img_bit = self._infer(lr, idx_scale)
            self.set_bit(teacher=False)

            sr = utility.quantize(sr, self.args.rgb_range)
            if self.args.save_results:
                save_name = '{}_x{}_{:.2f}bit'.format(filename, scale, img_bit)
                self.ckp.save_results('test_own', save_name, [sr])

    def _teacher_test_sets(self):
        """Teacher pass over every loader and scale, logging mean PSNR, SSIM and bit-width."""
        for d in self.loader_test:
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                tot_ssim = 0
                tot_bit = 0
                tot_psnr = 0.0
                for lr, hr, filename in tqdm(d, ncols=80):
                    lr, hr = self.prepare(lr, hr)
                    self.set_bit(teacher=True)
                    sr, _feat, img_bit = self._infer(lr, idx_scale)
                    self.set_bit(teacher=False)

                    sr = utility.quantize(sr, self.args.rgb_range)
                    save_list = [sr]
                    psnr, ssim = utility.calc_psnr(sr, hr, scale, self.args.rgb_range, dataset=d)

                    tot_bit += img_bit
                    tot_psnr += psnr
                    tot_ssim += ssim

                    if self.args.save_gt:
                        save_list.extend([lr, hr])

                    if self.args.save_results:
                        # Previously formatted with the undefined name `cur_psnr` (NameError).
                        save_name = '{}_x{}_{:.2f}dB'.format(filename[0], scale, psnr)
                        self.ckp.save_results(d, save_name, save_list)

                # Averages are taken over len(d), the number of batches the loader yields.
                tot_psnr /= len(d)
                tot_ssim /= len(d)
                tot_bit /= len(d)

                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} \t SSIM: {:.4f} \tBit: {:.2f}'.format(
                        d.dataset.name,
                        scale,
                        tot_psnr,
                        tot_ssim,
                        float(tot_bit),  # works for a 0-dim tensor and for a plain number
                    )
                )

    def prepare(self, *args):
        """Move each tensor to CPU or CUDA (per ``args.cpu``), casting to half when ``args.precision == 'half'``."""
        device = torch.device('cpu' if self.args.cpu else 'cuda')

        def _prepare(tensor):
            if self.args.precision == 'half':
                tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]

    def terminate(self):
        """Return True when the run should stop.

        In ``test_only`` mode this runs ``test()`` once and returns True.
        Otherwise it returns True once ``self.epoch`` exceeds ``args.epochs``.
        """
        if self.args.test_only:
            self.test()
            return True
        return self.epoch > self.args.epochs

import functools
import torch
from torch.nn import functional as F


def reduce_loss(loss, reduction):
    """Collapse an element-wise loss tensor according to ``reduction``.

    Args:
        loss (Tensor): Element-wise loss values.
        reduction (str): How to reduce ``loss``:
            - ``'none'`` returns ``loss`` unchanged;
            - ``'mean'`` returns its mean;
            - ``'sum'`` returns its sum.
            The string is resolved by PyTorch's private
            ``F._Reduction.get_enum``, which raises ``ValueError`` for
            unrecognised options.

    Returns:
        Tensor: ``loss`` itself for ``'none'``, otherwise the reduced tensor.
    """
    # get_enum codes: 0 -> none, 1 -> mean, 2 -> sum
    mode = F._Reduction.get_enum(reduction)
    if mode == 0:
        return loss
    return loss.mean() if mode == 1 else loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply an optional element-wise weight to ``loss`` and reduce it.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor, optional): Element-wise weights with the same number of
            dimensions as ``loss``. For tensors with a channel axis
            (``dim() >= 2``), ``weight.size(1)`` must be 1 (broadcast over
            channels) or equal ``loss.size(1)``. Lower-rank tensors are
            weighted element by element. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values. With a weight and ``reduction='mean'`` the result
        is ``sum(loss * weight)`` divided by the total weight. That total is
        multiplied by the channel count when a single-channel weight is
        broadcast over channels, so an all-zero weight yields nan/inf.

    Raises:
        ValueError: If ``weight`` does not match ``loss`` or ``reduction`` is
            not a recognised option.
    """
    if weight is None:
        return reduce_loss(loss, reduction)

    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if weight.dim() != loss.dim():
        raise ValueError(
            'weight has {} dims but loss has {}'.format(weight.dim(), loss.dim()))
    has_channels = loss.dim() >= 2
    if has_channels and weight.size(1) not in (1, loss.size(1)):
        raise ValueError(
            'weight.size(1) must be 1 or {}, got {}'.format(loss.size(1), weight.size(1)))
    loss = loss * weight

    # Resolve the reduction explicitly so an unknown option raises here,
    # instead of silently returning the unreduced weighted loss.
    mode = F._Reduction.get_enum(reduction)  # 0: none, 1: mean, 2: sum
    if mode == 0:
        return loss
    if mode == 2:
        return loss.sum()

    # Mean over the weighted region.
    if has_channels and weight.size(1) == 1:
        # A single-channel weight covers every channel of ``loss``.
        total_weight = weight.sum() * loss.size(1)
    else:
        total_weight = weight.sum()
    return loss.sum() / total_weight


def weighted_loss(loss_func):
    """Decorate an element-wise loss so it also accepts ``weight`` and ``reduction``.

    ``loss_func(pred, target, **kwargs)`` must return the unreduced,
    element-wise loss. The returned callable has the signature
    ``loss_func(pred, target, weight=None, reduction='mean', **kwargs)``. It
    evaluates ``loss_func`` and passes the result to ``weight_reduce_loss``.
    ``functools.wraps`` keeps the decorated function's name and docstring.

    Example::

        >>> import torch
        >>> @weighted_loss
        ... def l1_loss(pred, target):
        ...     return (pred - target).abs()
        >>> pred = torch.Tensor([0, 2, 3])
        >>> target = torch.Tensor([1, 1, 1])
        >>> l1_loss(pred, target)
        tensor(1.3333)
        >>> l1_loss(pred, target, reduction='none')
        tensor([1., 1., 2.])
        >>> # weights are checked against dim 1, so use an (N, C, ...) layout
        >>> weight = torch.Tensor([1, 0, 1]).view(1, 1, 3)
        >>> l1_loss(pred.view(1, 1, 3), target.view(1, 1, 3), weight)
        tensor(1.5000)
        >>> l1_loss(pred.view(1, 1, 3), target.view(1, 1, 3), weight, reduction='sum')
        tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # Element-wise loss first; weighting and reduction are delegated.
        return weight_reduce_loss(loss_func(pred, target, **kwargs), weight, reduction)

    return wrapper


@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    """Element-wise Charbonnier penalty ``sqrt((pred - target)^2 + eps)``.

    ``@weighted_loss`` adds the ``weight`` and ``reduction`` arguments
    (default ``reduction='mean'``), so callers get a reduced loss unless they
    pass ``reduction='none'``.

    Args:
        pred (Tensor): Prediction.
        target (Tensor): Ground truth, broadcast against ``pred``.
        eps (float): Added under the square root. This keeps the gradient
            finite (zero) where ``pred == target``.
    """
    squared_error = (pred - target).pow(2)
    return (squared_error + eps).sqrt()