from enum import auto
import logging
from collections import OrderedDict
from data.util import *
import torch
import torch.nn as nn
from torchvision import transforms
import os
import idm.networks as networks
from .base_model import BaseModel
import random
import data.util as Util
import numpy as np
from .sr3_modules.elan import ELAN
from .sr3_modules.esrt import ESRT
from .sr3_modules.vit import ViT

# Registry of stage-1 prior networks; GuideTSR looks up cfg["prior_type"] here.
prior_model = dict(esrt=ESRT, elan=ELAN)

class DDPM(BaseModel):
    """Wrapper that owns the generator ``netG`` and drives training/inference.

    The generator is built by ``networks.define_G(opt)``; its loss and the
    ``'train'`` noise schedule are configured on construction. In the
    ``'train'`` phase an Adam optimizer is created. Weights are resumed from
    ``<out_path>/models/idm_latest.pth`` when ``opt['train']['resume']`` is
    positive and that file exists.
    """

    def __init__(self, opt):
        super(DDPM, self).__init__(opt)
        # define network and load pretrained models
        self.netG = self.set_device(networks.define_G(opt))
        self.schedule_phase = None
        # NOTE(review): hard-coded CUDA device; this may override a device
        # chosen by BaseModel and presumably rules out CPU-only runs.
        self.device = torch.device('cuda')
        # set loss and load resume state
        self.set_loss()
        self.set_new_noise_schedule(
            opt['model']['beta_schedule']['train'], schedule_phase='train')
        # Created for every phase so get_current_log() does not raise
        # AttributeError outside training.
        self.log_dict = OrderedDict()
        if self.opt['phase'] == 'train':
            self.netG.train()
            optim_params = self._collect_optim_params(
                opt['model']['finetune_norm'])
            self.optG = torch.optim.Adam(
                optim_params, lr=opt['train']["optimizer"]["lr"])
        # Constant 0.5 tensors of shape (1, 1, 1, 1); presumably normalisation
        # mean/std for external callers -- not used inside this class.
        self.sub = torch.FloatTensor([0.5]).view(1, -1, 1, 1)
        self.div = torch.FloatTensor([0.5]).view(1, -1, 1, 1)
        self.print_network()
        self.load_network()

    def _unwrap_netG(self):
        """Return the generator, stripping a DistributedDataParallel wrapper."""
        if isinstance(self.netG, nn.parallel.DistributedDataParallel):
            return self.netG.module
        return self.netG

    def _collect_optim_params(self, finetune_norm):
        """Return the generator parameters the optimizer should update.

        When ``finetune_norm`` is truthy, every parameter is frozen except
        those whose name contains ``'transformer'``; those are zeroed, made
        trainable and returned. Otherwise all generator parameters are
        returned.
        """
        if not finetune_norm:
            return list(self.netG.parameters())
        optim_params = []
        for k, v in self.netG.named_parameters():
            v.requires_grad = False
            if 'transformer' in k:
                v.requires_grad = True
                v.data.zero_()
                optim_params.append(v)
                logging.info(
                    'Params [{:s}] initialized to 0 and will optimize.'.format(k))
        return optim_params

    def feed_data(self, data):
        """Convert a dataloader batch into the dict consumed by ``netG``.

        Args:
            data: mapping with at least the keys ``'lr'``, ``'gt'`` and
                ``'mask'``.

        Stores in ``self.data`` (after ``set_device``) a dict with keys
        ``'inp'``, ``'coord'``, ``'cell'``, ``'mask'``, ``'gt'`` and
        ``'scaler'``, where ``'scaler'`` is a 1-element float32 tensor drawn
        from ``random.random()``.
        """
        p = random.random()

        img_lr, img_hr, mask = data['lr'], data['gt'], data['mask']

        # NOTE(review): presumably img_hr is (B, C, H, W) and to_pixel_samples
        # yields (H*W, 2) coordinates; they are repeated over the batch below.
        hr_coord, _ = Util.to_pixel_samples(img_hr)
        cell = torch.ones_like(hr_coord)
        # Cell size per query: 2 / H along the first coord axis, 2 / W along
        # the second (presumably coordinates span [-1, 1]).
        cell[:, 0] *= 2 / img_hr.shape[-2]
        cell[:, 1] *= 2 / img_hr.shape[-1]
        hr_coord = hr_coord.repeat(img_hr.shape[0], 1, 1)
        cell = cell.repeat(img_hr.shape[0], 1, 1)

        data = {
            'inp': img_lr,
            'coord': hr_coord,
            'cell': cell,
            'mask': mask,
            'gt': img_hr,
            'scaler': torch.from_numpy(np.array([p], dtype=np.float32))}

        self.data = self.set_device(data)

    def optimize_parameters(self, scaler=0):
        """Run one optimizer step on the batch stored by ``feed_data``.

        ``netG(self.data)`` is expected to return a loss tensor; it is summed,
        divided by the element count of ``gt``, back-propagated, and the
        scalar value is recorded as ``log_dict['l_pix']``.

        Args:
            scaler: unused; kept for call-site compatibility.
        """
        self.optG.zero_grad()

        l_pix = self.netG(self.data)
        # need to average in multi-gpu
        b, c, h, w = self.data['gt'].shape
        l_pix = l_pix.sum() / int(b * c * h * w)

        l_pix.backward()
        self.optG.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self, continous=False):
        """Run ``super_resolution`` on the current batch without gradients.

        The result is stored in ``self.SR``; the generator is put back into
        train mode afterwards.
        """
        self.netG.eval()
        with torch.no_grad():
            self.SR = self._unwrap_netG().super_resolution(
                self.data, continous)
        self.netG.train()

    def sample(self, batch_size=1, continous=False):
        """Draw ``batch_size`` samples from the generator into ``self.SR``."""
        self.netG.eval()
        with torch.no_grad():
            self.SR = self._unwrap_netG().sample(batch_size, continous)
        self.netG.train()

    def set_loss(self):
        """Forward ``self.device`` to the generator's ``set_loss``."""
        self._unwrap_netG().set_loss(self.device)

    def set_new_noise_schedule(self, schedule_opt, schedule_phase='train'):
        """Install ``schedule_opt`` on the generator if the phase changed."""
        if self.schedule_phase is None or self.schedule_phase != schedule_phase:
            self.schedule_phase = schedule_phase
            self._unwrap_netG().set_new_noise_schedule(schedule_opt, self.device)

    def get_current_log(self):
        """Return the log dict (``'l_pix'`` after an optimization step)."""
        return self.log_dict

    def get_current_visuals(self, need_LR=True, sample=False):
        """Return detached float32 CPU copies of the current tensors.

        With ``sample=True`` only ``'SAM'`` (the sampled output) is returned;
        otherwise ``'SR'``, ``'INF'`` (the input), ``'HR'`` and ``'LR'``.
        """
        out_dict = OrderedDict()
        if sample:
            out_dict['SAM'] = self.SR.detach().float().cpu()
        else:
            out_dict['SR'] = self.SR.detach().float().cpu()
            out_dict['INF'] = self.data['inp'].detach().float().cpu()
            out_dict['HR'] = self.data['gt'].detach().float().cpu()
            # NOTE(review): feed_data never stores an 'LR' key, so with data
            # from feed_data the else-branch (reuse of 'INF') is taken.
            if need_LR and 'LR' in self.data:
                out_dict['LR'] = self.data['inp'].detach().float().cpu()
            else:
                out_dict['LR'] = out_dict['INF']
        return out_dict

    def print_network(self):
        """Log the generator's class name(s) and parameter count."""
        _, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.parallel.DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logging.info(
            'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))

    def save_network(self, epoch="", path=""):
        """Save the (unwrapped) generator's state_dict to ``path`` on CPU.

        Args:
            epoch: unused; kept for call-site compatibility.
            path: destination file passed to ``torch.save``.
        """
        gen_path = path
        network = self._unwrap_netG()
        state_dict = network.state_dict()
        # Values are replaced in place (keys unchanged) so the state_dict
        # object -- including any attributes torch attaches -- is preserved.
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, gen_path)
        logging.info('Saved model in [{:s}] ...'.format(gen_path))

    def load_network(self):
        """Resume generator weights from ``<out_path>/models/idm_latest.pth``.

        Does nothing unless ``opt['train']['resume'] > 0``. A missing
        checkpoint is logged as a warning and skipped. Optimizer state and
        step/epoch counters are not restored.
        """
        if not (self.opt['train']['resume'] > 0):
            return
        logging.info(
            'Loading pretrained model for G [{}] ...'.format(self.opt['train']['resume']))
        gen_path = '{}/models/idm_latest.pth'.format(self.opt['out_path'])
        if not os.path.isfile(gen_path):
            # Best-effort resume: skip, but make the skip visible.
            logging.warning(
                'Resume checkpoint [{:s}] not found; keeping current weights.'.format(gen_path))
            return
        network = self._unwrap_netG()
        network.load_state_dict(torch.load(
            gen_path, map_location=torch.device('cpu')), strict=True)
        logging.info('load success')

# ViTSR -------------------------------------------------------------
class GuideTSR(nn.Module):
    """Two-stage model: a prior network followed by a ``ViT``.

    ``cfg["prior_type"]`` selects the stage-1 class from ``prior_model`` and
    ``cfg[cfg["prior_type"]]`` supplies its keyword arguments; stage 2 is
    ``ViT(**cfg["model"]["config"])`` applied to the stage-1 output.
    In training mode ``forward`` returns ``(stage1_out, stage2_out)``;
    otherwise only ``stage2_out``.
    """

    def __init__(self, cfg=None):
        super().__init__()
        # ``None`` instead of a shared mutable ``{}`` default; an omitted
        # config still raises KeyError('prior_type') exactly as before.
        cfg = {} if cfg is None else cfg
        prior_type = cfg["prior_type"]
        self.stg1 = prior_model[prior_type](**cfg[prior_type])
        self.stg2 = ViT(**cfg["model"]["config"])

    def forward(self, x):
        out_1 = self.stg1(x)
        out_2 = self.stg2(out_1)
        if self.training:
            # Both stage outputs are exposed during training.
            return out_1, out_2
        return out_2

# image to image vit autoencoder ------------------------------------
class I2IViT(nn.Module):
    """Image-to-image model consisting of a single ``ViT``.

    ``cfg["model"]["config"]`` is unpacked as keyword arguments to ``ViT``.
    """

    def __init__(self, cfg=None):
        super().__init__()
        # ``None`` instead of a shared mutable ``{}`` default; an omitted
        # config still raises KeyError('model') exactly as before.
        cfg = {} if cfg is None else cfg
        self.vit = ViT(**cfg["model"]["config"])

    def forward(self, x):
        return self.vit(x)
    
# discriminator -----------------------------------------------------
class NLayerDiscriminator(nn.Module):
    """Convolutional discriminator returning a spatial feature map.

    Builds ``floor(log2(img_size // 8))`` stages of spectral-normalised
    5x5 stride-2 conv + LeakyReLU(0.2); stage ``i`` outputs ``nf * (i + 1)``
    channels and halves H and W (rounding up). A final 3x3 conv keeps the
    channel count and resolution.

    Args:
        in_channels: channels of ``x`` (plus ``cond`` if one is passed).
        img_size: input size used to choose the number of stages; must be
            >= 16 so at least one stage is built.
        nf: base channel width.

    Raises:
        ValueError: if ``img_size`` < 16.
    """

    def __init__(self, in_channels=3, img_size=256, nf=32) -> None:
        super().__init__()
        if img_size < 16:
            # Smaller sizes produce zero stages (or log2(0) = -inf) and used
            # to crash with UnboundLocalError / OverflowError.
            raise ValueError(
                "img_size must be >= 16 to build at least one downsampling "
                "stage, got {}".format(img_size))
        n_stages = int(np.log2(img_size // 8))
        self.model = nn.Sequential()
        for i in range(n_stages):
            ch_in = in_channels if i == 0 else nf * i
            ch_out = nf * (i + 1)
            self.model.append(nn.Sequential(
                nn.utils.spectral_norm(
                    nn.Conv2d(ch_in, ch_out, kernel_size=5, stride=2, padding=2)),
                nn.LeakyReLU(0.2)))
        # The last stage (i = n_stages - 1) produced nf * n_stages channels.
        self.model.append(nn.Conv2d(nf * n_stages, nf * n_stages,
                                    kernel_size=3, stride=1, padding=1))

    def forward(self, x, cond=None):
        """Score ``x``, optionally concatenated with ``cond`` on dim 1."""
        if cond is not None:
            x = torch.cat([x, cond], dim=1)
        return self.model(x)