import math
import torch
import torch.nn as nn
import pytorch_lightning as pl
from model.experiment import DiffusionExperiment, add_exp_args
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms


class Experiment(DiffusionExperiment):
    """Diffusion experiment driven by a hand-written PyTorch training loop.

    Batches are dicts with at least ``'image'`` and ``'text'`` entries; the
    batch size is taken from ``x['image'].shape[0]``. Reported losses are the
    negated sum of the model output divided by
    ``ln(2) * _TOKENS_PER_SAMPLE * batch_size`` (printed as "Bits/char").
    """

    # Tokens per sample used to normalise the loss. NOTE(review): presumably
    # 128 text tokens + 1024 image tokens -- confirm against the model config.
    _TOKENS_PER_SAMPLE = 128 + 1024

    @staticmethod
    def _batch_size(x):
        """Return the number of samples in batch ``x`` (rows of ``x['image']``)."""
        return x['image'].shape[0]

    def _bits_per_token(self, x, **model_kwargs):
        """Return ``-sum(model(x)) / (ln 2 * tokens in batch)`` as a scalar tensor."""
        num_elem = self._TOKENS_PER_SAMPLE * self._batch_size(x)
        return -self.model(x, **model_kwargs).sum() / (math.log(2) * num_elem)

    def _eval_pass(self, progress_header):
        """Run the model over ``self.eval_loader`` without gradients.

        Prints a running progress line prefixed by ``progress_header`` and
        returns ``(loss_sum, loss_count)``: the per-batch loss weighted by
        batch size, and the number of samples seen. Leaves the model in eval
        mode.
        """
        self.model.eval()
        loss_sum = 0.0
        loss_count = 0
        with torch.no_grad():
            for x in self.eval_loader:
                batch_size = self._batch_size(x)
                loss = self._bits_per_token(x)
                # Weight by samples in the batch; len(x) would count dict keys.
                loss_sum += loss.detach().cpu().item() * batch_size
                loss_count += batch_size
                print('Evaluating. {}, Datapoint: {}/{}, Bits/char: {:.3f}'.format(
                    progress_header, loss_count, len(self.eval_loader.dataset),
                    loss_sum / loss_count), end='\r')
            print('')
        return loss_sum, loss_count

    def train_fn(self, epoch):
        """Train for one epoch over ``self.train_loader``.

        Gradients are accumulated and the optimizer is stepped every
        ``self.args.update_freq`` iterations. A checkpoint is saved every 100
        iterations and ``eval_stepfn`` runs every 1000. When ``self.args.debug``
        is truthy the epoch stops once more than that many samples were seen.

        Returns:
            ``{'bpc': mean per-sample loss over the epoch}``.
        """
        self.model.train()
        loss_sum = 0.0
        loss_count = 0
        loss_moving = None
        for iteration, x in enumerate(self.train_loader):
            batch_size = self._batch_size(x)
            assert batch_size == len(x['text'])

            loss = self._bits_per_token(x, train=True)
            loss.backward()
            if (iteration + 1) % self.args.update_freq == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.scheduler_iter:
                    self.scheduler_iter.step()

            loss_value = loss.detach().cpu().item()
            # Weight by samples in the batch; len(x) would count dict keys.
            loss_sum += loss_value * batch_size
            loss_count += batch_size

            if (iteration + 1) % 100 == 0:
                self.checkpoint_save()

            if (iteration + 1) % 1000 == 0:
                # eval_stepfn restores the model's train mode when it finishes.
                self.eval_stepfn(iteration)

            if loss_moving is None:
                loss_moving = loss_value
            else:
                loss_moving = .99 * loss_moving + .01 * loss_value

            if self.args.debug and loss_count > self.args.debug:
                break
            print('Training. Epoch: {}/{}, Datapoint: {}/{}, Bits/char: {:.3f}'.format(
                epoch + 1, self.args.epochs, loss_count,
                len(self.train_loader.dataset), loss_moving), end='\r')
        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()
        return {'bpc': loss_sum / loss_count}

    def eval_fn(self, epoch):
        """Evaluate on ``self.eval_loader`` and return ``{'bpc': mean loss}``.

        The model is left in eval mode afterwards.
        """
        loss_sum, loss_count = self._eval_pass(
            'Epoch: {}/{}'.format(epoch + 1, self.args.epochs))
        return {'bpc': loss_sum / loss_count}

    def eval_stepfn(self, steps):
        """Mid-epoch evaluation on ``self.eval_loader`` (progress output only).

        ``steps`` is the zero-based training iteration that triggered the
        evaluation. The model's previous train/eval mode is restored
        afterwards, so the calling training loop keeps running in train mode.
        """
        was_training = self.model.training
        try:
            self._eval_pass('Step: {}/{}'.format(steps + 1, len(self.train_loader)))
        finally:
            self.model.train(was_training)


def get_text(tokens, tkizer):
    """Decode each token sequence in ``tokens`` with ``tkizer.decode``.

    Args:
        tokens: iterable of token sequences (presumably rows of a 2-D token
            tensor -- confirm against callers).
        tkizer: object exposing ``decode(sequence)``.

    Returns:
        List of decoded values, one per input sequence, in input order.
    """
    return [tkizer.decode(seq) for seq in tokens]
class ExperimentPL(pl.LightningModule):
    """Lightning wrapper that trains and validates a diffusion model on image+text batches.

    Args:
        model: module called as ``model(batch, train=..., cond=...)``; its
            output is summed and negated to form the loss (NOTE(review):
            presumably a log-likelihood -- confirm).
        optimizer: optimizer handed to Lightning by ``configure_optimizers``.
        scheduler: LR scheduler handed to Lightning, or ``None`` to train
            without one.
        cond: conditioning mode forwarded to ``model`` as ``cond=``.
        tokens_per_sample: tokens per sample used to normalise the loss.
            Defaults to ``128 + 1024``; earlier commented-out variants used
            ``64 + 1024`` and ``1024`` (tagged "txt").
    """

    def __init__(self, model, optimizer, scheduler, cond=None, tokens_per_sample=128 + 1024):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.cond = cond
        self.tokens_per_sample = tokens_per_sample

    def forward(self, x, train=False):
        """Call the wrapped model on ``x`` with the given ``train`` flag."""
        return self.model(x, train=train)

    def _num_elem(self, batch):
        """Return the token count used to normalise ``batch``'s loss.

        Asserts that ``batch['image']`` and ``batch['text']`` hold the same
        number of samples.
        """
        assert batch['image'].shape[0] == len(batch['text'])
        return self.tokens_per_sample * batch['image'].shape[0]

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch.

        NOTE(review): unlike ``validation_step`` this loss is not divided by
        ``ln(2)``, so ``train/diffusion`` and ``eval/diffusion`` differ by that
        factor. Left unchanged because rescaling would alter the optimised
        objective -- confirm this is intended.
        """
        num_elem = self._num_elem(batch)
        loss = - self.model(batch, train=True, cond=self.cond).sum() / num_elem
        self.log("train/diffusion", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        # Disabled sample/image logging:
        # if batch_idx % 2000 == 0:
        #     batch_size_for_gen = batch['image'].shape[0] if batch['image'].shape[0] <=2 else 2
        #     self.model.eval()
        #     with torch.no_grad():
        #         samples = self.model.sample_chain_mask_fast(batch_size_for_gen)
        #         self.log_images(samples,  batch['image'], batch['text'])
        #     self.model.train()
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss (normalised by ``ln(2) * num_elem``)."""
        num_elem = self._num_elem(batch)
        loss = - self.model(batch, train=False, cond=self.cond).sum() / (math.log(2) * num_elem)
        self.log("eval/diffusion", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        # Disabled sample/image logging:
        # if batch_idx % 2000 == 0:
        #     batch_size_for_gen = batch['image'].shape[0] if batch['image'].shape[0] <=2 else 2
        #     self.model.eval()
        #     with torch.no_grad():
        #         samples = self.model.sample_chain_mask_fast(batch_size_for_gen)
        #         self.log_images(samples,  batch['image'][:batch_size_for_gen], batch['text'][:batch_size_for_gen])
        return {'loss': loss}

    # Disabled image/text logging helper used by the commented code above:
    # @torch.no_grad()
    # def log_images(self, samples, raw_images, raw_texts):
    #     img_token = samples[0][:,:self.model.image_tokenizer.token_shape[0]**2]
    #     #avoid mask tokens
    #     img_token[img_token == self.model.img_classes+self.model.txt_classes] = 0
    #     text_token = samples[0][:,self.model.image_tokenizer.token_shape[0]**2:]
    #     gen_image = self.model.image_tokenizer.decode(img_token)
    #     gen_text = get_text(text_token-self.model.img_classes,self.model.text_tokenizer)
    #     raw_text = get_text(raw_texts,self.model.text_tokenizer)
    #     bg = Image.new("RGB", (1000, 100), (255, 255, 255))
    #     both_text_img = ImageDraw.Draw(bg)
    #     font = ImageFont.truetype('/home/hu/UniDm/misc/CascadiaCode-Regular.otf', 14)
    #     both_text_img.text((10, 5), 'Sample_Text', font=font, fill="#000000")
    #     both_text_img.text((10, 20), gen_text[0], font=font, fill="#000000")
    #     both_text_img.text((10, 35), gen_text[1], font=font, fill="#000000")
    #     both_text_img.text((10, 55), 'GT_Text', font=font, fill="#000000")
    #     both_text_img.text((10, 70), raw_text[0], font=font, fill="#000000")
    #     both_text_img.text((10, 85), raw_text[1], font=font, fill="#000000")
    #     raw_im = raw_images.permute(0,3,1,2)/255.0 - 1
    #     txt_im = transforms.ToTensor()(bg)
    #     self.logger.log_image(key="inputs",images=[raw_im])
    #     self.logger.log_image(key="gen",images=[gen_image])
    #     self.logger.log_image(key="text",images=[txt_im])

    def configure_optimizers(self):
        """Return the Lightning optimizer config.

        ``"lr_scheduler"`` is included only when a scheduler was supplied;
        Lightning rejects ``None`` as a scheduler. ``"monitor"`` names the
        epoch-level validation metric logged by ``validation_step``.
        """
        config = {"optimizer": self.optimizer, "monitor": "eval/diffusion_epoch"}
        if self.scheduler is not None:
            config["lr_scheduler"] = self.scheduler
        return config
