import utility
import torch
from tqdm import tqdm
import numpy as np
import copy
from data.bicubic import bicubic
import loss
import os
import dist_util
from model.lsq_plus import ActLSQ

class Trainer:
    """Training / evaluation driver for the image-restoration model.

    Wraps a model, its reconstruction criterion, an optimizer built by
    ``utility.make_optimizer`` and a checkpoint/logger object ``ckp``.
    Evaluation supports deraining, denoising and super-resolution test sets.
    """

    # Weight of the contrastive term in the total pretraining loss.
    CON_LOSS_WEIGHT = 0.1

    def __init__(self, args, loader, my_model, my_loss, ckp):
        """
        Args:
            args: parsed options (read for ``scale``, ``cpu``, ``precision``,
                ``save_results``, ``log_frequency``, ...).
            loader: object exposing ``loader_train`` and ``loader_test``.
            my_model: the network being trained / evaluated.
            my_loss: reconstruction criterion, called as ``my_loss(sr, hr)``.
            ckp: checkpoint/logger object (``write_log``, ``add_log``,
                ``log``, ``save_results``, ...).
        """
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.use_con = True
        self.optimizer = utility.make_optimizer(args, self.model)
        self.last_epoch = None
        self.bicubic = bicubic(args)
        self.error_last = 1e8
        self.criterion = my_loss
        self.con_loss = loss.SupConLoss()

    def pretrain(self, epoch):
        """Train for one pass over ``loader_train``, then call ``optimizer.schedule()``.

        When ``use_con`` is True, the model is called as
        ``model(lr, idx_scale, True)`` and must return ``(sr, x_con)``. The
        loss is ``criterion(sr, hr) + CON_LOSS_WEIGHT * con_loss(x_con)``.
        Otherwise only the reconstruction loss is used.

        Args:
            epoch: epoch index; used for the progress log and stored in
                ``self.last_epoch``.
        """
        self.model.train()
        running_loss = 0.0

        print('data length: ', len(self.loader_train))
        for batch_idx, (lr, hr, _) in enumerate(self.loader_train):
            idx_scale = self.loader_train.dataset.idx_scale
            lr, hr = self.prepare(lr, hr)
            # Forward in float32 even if prepare() applied .half()
            # for --precision half.
            hr = hr.to(torch.float32)
            lr = lr.to(torch.float32)

            if self.use_con:
                sr, x_con = self.model(lr, idx_scale, self.use_con)
                loss1 = self.criterion(sr, hr)
                loss2 = self.con_loss(x_con)
                total_loss = loss1 + self.CON_LOSS_WEIGHT * loss2
            else:
                # Plain reconstruction training (previously this path raised
                # UnboundLocalError because no loss was ever computed).
                sr = self.model(lr, idx_scale)
                loss1 = self.criterion(sr, hr)
                loss2 = torch.zeros_like(loss1)
                total_loss = loss1

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # Detach before accumulating. Summing the grad-tracking tensor
            # would chain every batch's autograd graph for the whole epoch.
            running_loss = running_loss + total_loss.detach()

            if batch_idx % self.args.log_frequency == 0:
                print('Epoch: {}, Step: {}, Task: {}, tmp_epoch_avg_loss: {:.4f}, loss: {:.4f}, loss1: {:.4f}, loss2: {:.4f}'.format(
                    epoch, batch_idx, idx_scale,
                    float(running_loss / (batch_idx + 1)),
                    float(total_loss), float(loss1), float(loss2)))

        self.optimizer.schedule()
        self.last_epoch = epoch

    def test(self, args=None):
        """Evaluate on every test loader and log PSNR to ``ckp``.

        Args:
            args: options object providing ``set_task``. Only the scale whose
                index equals ``args.set_task`` is evaluated. Defaults to
                ``self.args`` so ``terminate()`` can call ``self.test()``.
        """
        if args is None:
            args = self.args

        # no_grad restores the previous grad mode even if evaluation raises.
        with torch.no_grad():
            self.ckp.write_log('\nEvaluation:')
            self.ckp.add_log(
                torch.zeros(1, len(self.loader_test), len(self.scale))
            )
            self.model.eval()
            timer_test = utility.timer()
            if self.args.save_results:
                self.ckp.begin_background()

            for idx_data, d in enumerate(self.loader_test):
                for idx_scale, scale in enumerate(self.scale):
                    if idx_scale != args.set_task:
                        continue
                    d.dataset.set_scale(idx_scale)
                    if self.args.derain:
                        self._test_derain(d, idx_data, idx_scale, scale)
                    elif self.args.denoise:
                        self._test_denoise(d, idx_data, idx_scale, scale)
                    else:
                        self._test_sr(d, idx_data, idx_scale, scale)

            self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
            self.ckp.write_log('Saving...')

            if self.args.save_results:
                self.ckp.end_background()

            self.ckp.write_log(
                'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
            )

    def _write_psnr_log(self, d, idx_data, idx_scale, scale):
        """Average the accumulated PSNR over ``len(d)`` and log it with the best so far."""
        self.ckp.log[-1, idx_data, idx_scale] /= len(d)
        best = self.ckp.log.max(0)
        self.ckp.write_log(
            '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                d.dataset.name,
                scale,
                self.ckp.log[-1, idx_data, idx_scale],
                best[0][idx_data, idx_scale],
                best[1][idx_data, idx_scale] + 1
            )
        )

    def _test_derain(self, d, idx_data, idx_scale, scale):
        """Evaluate one derain loader yielding ``(norain, rain, filename)``."""
        for norain, rain, filename in tqdm(d, ncols=80):
            norain, rain = self.prepare(norain, rain)
            sr = self.model(rain, idx_scale)
            sr = utility.quantize(sr, self.args.rgb_range)

            self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
                sr, norain, scale, self.args.rgb_range
            )
            if self.args.save_results:
                self.ckp.save_results(d, filename[0], [sr, rain, norain], 1)
        self._write_psnr_log(d, idx_data, idx_scale, scale)

    def _test_denoise(self, d, idx_data, idx_scale, scale):
        """Evaluate one denoise loader: add Gaussian noise of std ``args.sigma`` to HR."""
        for hr, _, filename in tqdm(d, ncols=80):
            hr = self.prepare(hr)[0]
            # Noise is drawn from the CPU generator (as before) and then moved
            # to hr's device, so this also works with --cpu.
            noise = torch.randn(hr.size()).mul_(self.args.sigma).to(hr.device)
            # NOTE(review): clamps to [0, 255] regardless of args.rgb_range.
            nois_hr = (noise + hr).clamp(0, 255)
            sr = self.model(nois_hr, idx_scale)
            sr = utility.quantize(sr, self.args.rgb_range)

            self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
                sr, hr, scale, self.args.rgb_range
            )
            if self.args.save_results:
                self.ckp.save_results(d, filename[0], [sr, nois_hr, hr], 50)
        self._write_psnr_log(d, idx_data, idx_scale, scale)

    def _test_sr(self, d, idx_data, idx_scale, scale):
        """Evaluate one super-resolution loader yielding ``(lr, hr, filename)``."""
        for lr, hr, filename in tqdm(d, ncols=80):
            lr, hr = self.prepare(lr, hr)
            sr = self.model(lr, idx_scale)
            sr = utility.quantize(sr, self.args.rgb_range)

            self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
                sr, hr, scale, self.args.rgb_range
            )
            save_list = [sr, lr, hr] if self.args.save_gt else [sr]
            if self.args.save_results:
                self.ckp.save_results(d, filename[0], save_list, scale)

        self.ckp.log[-1, idx_data, idx_scale] /= len(d)
        best = self.ckp.log.max(0)

        # last_epoch is only set by pretrain(); None means we never trained.
        if self.last_epoch is None:
            self.last_epoch = 0
            print('Now is the test_only mode !!!!!!')

        self.ckp.write_log(
            '[{} x{}]\tepoch:{}\t PSNR : {:.3f} (Best: {:.3f} @epoch {})'.format(
                d.dataset.name,
                scale,
                self.last_epoch + 1,
                self.ckp.log[-1, idx_data, idx_scale],  # current epoch PSNR
                best[0][idx_data, idx_scale],  # best PSNR so far
                best[1][idx_data, idx_scale] + 1  # epoch of best PSNR
            )
        )

    def prepare(self, *args):
        """Move each tensor to CPU/CUDA per ``args.cpu``.

        Each tensor is first cast with ``.half()`` when
        ``args.precision == 'half'``.

        Returns:
            A list with one converted tensor per positional argument.
        """
        device = torch.device('cpu' if self.args.cpu else 'cuda')

        def _prepare(tensor):
            if self.args.precision == 'half':
                tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]

    def terminate(self):
        """Return True when training should stop.

        In ``test_only`` mode this runs one evaluation and returns True.
        Otherwise it compares the next epoch index against ``args.epochs``.
        """
        if self.args.test_only:
            self.test()
            return True
        epoch = self.optimizer.get_last_epoch() + 1
        return epoch >= self.args.epochs

    def _np2Tensor(self, img, rgb_range):
        """Convert a numpy image to a float32 channel-first torch tensor.

        Assumes ``img`` is (H, W, C). A 2-D (H, W) array is treated as
        single-channel. Values are scaled by ``rgb_range / 255``. The input
        array is never modified.
        """
        if img.ndim == 2:
            img = img[:, :, np.newaxis]
        chw = np.ascontiguousarray(img.transpose((2, 0, 1)))
        # astype copies, so the returned tensor never aliases ``img``.
        tensor = torch.from_numpy(chw.astype(np.float32))
        return tensor * (rgb_range / 255)
