import copy
from random import Random

from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt
import torch
import torchvision.transforms.v2 as tv
from tqdm.auto import tqdm

from diagram import System
from models import (AutoEncoder, LinearGetter,
                    LinearPutter, LinearPutterWithComplement)
from rules.autoencoder import autoencode
from rules.lens import (classify, get_put_teacher_forced, put_get, put_put,
                        undo_teacher_forced)
from spriteworld import Color, Shape, SpriteWorld

# CONSTANTS
# Image parameters
# IMAGE_SIZE is handed to the AutoEncoder constructor and is the tile size
# used when laying out the final figure.
IMAGE_SIZE = 32
IMAGE_CHANNELS = 3
IMAGE_SHAPE = (IMAGE_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

# Architecture parameters
# Models and batches are moved to CUDA when it is available, else to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Size given to the shape/color getters and putters; also the exclusive upper
# bound of the random attribute indices drawn for every batch.
CONCEPT_SIZE = 3
# Second argument of the AutoEncoder constructor.
LATENT_SIZE = 32
# Last argument of the LinearPutterWithComplement constructor.
COMPLEMENT_SIZE = 8
# Remaining AutoEncoder constructor arguments, passed in this order.
CNN_HIDDEN_LAYERS = 4
CNN_HIDDEN_CHANNELS = 64
CNN_CHANNEL_MULTIPLIER = 1

# Training hyper-parameters
# Number of samples in every batch yielded by the dataset.
BATCH_SIZE = 512
# Gradient entries are clamped with clip_grad_value_ to this magnitude; a
# falsy value (0 / None) skips clipping altogether.
CLIP_GRADIENTS = 1
# AdamW learning rate and weight decay.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-2
# Total number of optimisation steps, and how often (in steps) the losses are
# printed.
STEPS = 100000
LOGGING_STEP = 1000
# Seed for Python's Random and for torch.manual_seed.
SEED = 0

# Converts a PIL image to a float32 tensor.  PILToTensor keeps the raw integer
# pixel values; ToDtype(..., scale=True) then rescales them for the float
# dtype (uint8 0..255 becomes 0.0..1.0).
pil_to_tensor = tv.Compose([
    tv.PILToTensor(),
    tv.ToDtype(torch.float32, scale=True)
])


class ImagesDataset(torch.utils.data.IterableDataset):
    """Endless stream of ready-made SpriteWorld batches.

    Each item produced by the iterator is a dictionary that already holds a
    full batch of ``batch_size`` samples, so the consumer should not batch
    the items again.
    """

    def __init__(self, spriteworld, batch_size: int) -> None:
        self.spriteworld, self.batch_size = spriteworld, batch_size
        # Created in ``__iter__`` so that it is seeded from the torch seed of
        # whichever process actually iterates the dataset.
        self.rng = None

    @staticmethod
    def _encode(sample):
        """Convert one sample into ``(image tensor, attribute indices, label)``."""
        image = pil_to_tensor(sample['image'].convert('RGB'))
        # Position of every attribute within its own type
        # (assumes the attributes are Enum members -- TODO confirm).
        indices = [list(type(attribute)).index(attribute)
                   for attribute in sample['attributes']]
        properties = sample['properties']
        return image, indices, get_label(properties[0], properties[1][0])

    def __iter__(self):
        """Yield batch dictionaries forever.

        Keys: ``image``, ``attributes``, ``label`` (from the sampled sprites)
        and ``random_attribute``/``random_attribute_2``,
        ``random_label``/``random_label_2`` (independent random targets).
        """
        self.rng = Random(torch.initial_seed())
        while True:
            batch_size = self.batch_size
            # Sampling and encoding stay interleaved, one sprite at a time.
            encoded = [self._encode(self.spriteworld.sample(self.rng))
                       for _ in range(batch_size)]
            images = [row[0] for row in encoded]
            attributes = [row[1] for row in encoded]
            labels = [row[2] for row in encoded]

            yield {
                'image': torch.stack(images),
                'attributes': torch.as_tensor(attributes),
                # Uniform class indices in [0, CONCEPT_SIZE).
                'random_attribute': torch.randint(high=CONCEPT_SIZE,
                                                  size=(batch_size,)),
                'random_attribute_2': torch.randint(high=CONCEPT_SIZE,
                                                    size=(batch_size,)),
                'label': torch.as_tensor(labels),
                # Drawn from [-0.1, 1.1) and clamped to [0, 1], so the two
                # end values occur with non-zero probability.
                'random_label': torch.empty(batch_size).uniform_(-0.1, 1.1).clamp(0, 1),
                'random_label_2': torch.empty(batch_size).uniform_(-0.1, 1.1).clamp(0, 1)
            }


def get_label(shape, color):
    """Return the scalar label in [0, 1] for a sprite's shape and color.

    ``color`` is first mapped to a closeness score ``c``: it rises linearly
    from 0 to 1 as ``color`` goes from 0 to ``Color.BLUE`` and falls linearly
    back towards 0 as ``color`` approaches 256 (presumably a hue on a
    0..255 scale -- confirm against ``spriteworld``).  The shape then shifts
    or clamps that score:

    * ellipse:   ``min(1, c + 0.6)``
    * rectangle: ``c`` clamped to ``[0.2, 0.8]``
    * triangle:  ``max(0, c - 0.6)``

    Raises:
        ValueError: if ``shape`` is not one of the three shapes above.  The
            function used to return ``None`` silently in that case, which only
            surfaced later when the labels were converted to a tensor.
    """
    if color < Color.BLUE:
        c = color / Color.BLUE
    else:
        c = 1 - (color - Color.BLUE) / (256 - Color.BLUE)

    if shape == Shape.ELLIPSE:
        return min(1, c + 0.6)
    if shape == Shape.RECTANGLE:
        return max(0.2, min(0.8, c))
    if shape == Shape.TRIANGLE:
        return max(0, c - 0.6)
    raise ValueError(f'Unsupported shape: {shape!r}')


def image_loss(pred, target):
    """Return mean squared error plus 0.25 times mean absolute error."""
    functional = torch.nn.functional
    squared_term = functional.mse_loss(pred, target)
    absolute_term = functional.l1_loss(pred, target)
    return squared_term + 0.25 * absolute_term


def ce_or_mse_loss(pred, target):
    """Pick the loss from the target dtype.

    Integer targets are treated as class indices and scored with
    cross-entropy; floating-point targets are scored with mean squared error.
    """
    if not target.is_floating_point():
        return torch.nn.functional.cross_entropy(pred, target)
    return torch.nn.functional.mse_loss(pred, target)


# A task is a ``System`` holding named rules.  Every rule pairs a rule function
# with the loss applied to it.  The keyword arguments bind each rule parameter
# to the name under which that input is supplied when the task is called
# (NOTE(review): inferred from how the tasks are invoked in the training loop
# -- confirm against ``diagram.System``).
autoencoder_task = System()
autoencoder_task.add_rule(
    'autoencode', autoencode, image_loss,
    autoencoder='autoencoder', data='data',
)

# Rules shared by every getter/putter pair, registered in this order.
# 'put_get' is the only rule wired to 'getter_clone' instead of 'getter'.
_MANIPULATOR_RULES = (
    ('classify', classify, ce_or_mse_loss,
     {'getter': 'getter', 'state': 'state', 'value': 'value'}),
    ('get_put', get_put_teacher_forced, image_loss,
     {'putter': 'putter', 'state': 'state', 'value': 'value'}),
    ('put_get', put_get, ce_or_mse_loss,
     {'putter': 'putter', 'getter': 'getter_clone',
      'state': 'state', 'value': 'random_value'}),
    ('put_put', put_put, image_loss,
     {'putter': 'putter', 'state': 'state',
      'value1': 'random_value', 'value2': 'random_value_2'}),
    ('undo', undo_teacher_forced, image_loss,
     {'putter': 'putter', 'state': 'state',
      'value': 'value', 'random_value': 'random_value'}),
)

manipulator_task = System()
for _rule_name, _rule, _rule_loss, _bindings in _MANIPULATOR_RULES:
    manipulator_task.add_rule(_rule_name, _rule, _rule_loss, **_bindings)

# Seed Python's and torch's generators so runs are repeatable.
# NOTE(review): ``rng`` is not used anywhere else in this script as shown.
rng = Random(SEED)
torch.manual_seed(SEED)

# NOTE(review): the positional SpriteWorld arguments are not explained here;
# the leading 32 presumably mirrors IMAGE_SIZE -- confirm in ``spriteworld``.
spriteworld = SpriteWorld(32, 1, (0, 0, 0), 8, (64, 255), (64, 255), 3, 3, 3, 3)
dataset = ImagesDataset(spriteworld, BATCH_SIZE)

# ``batch_size=None`` disables automatic batching: the dataset already yields
# complete batches.  One worker process prepares batches in the background.
# NOTE(review): worker processes are started from module level; on platforms
# that spawn (rather than fork) workers this needs a ``__main__`` guard.
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=None,
                                         num_workers=1,
                                         prefetch_factor=2)

autoencoder = AutoEncoder(IMAGE_SIZE,
                          LATENT_SIZE,
                          CNN_HIDDEN_LAYERS,
                          CNN_HIDDEN_CHANNELS,
                          CNN_CHANNEL_MULTIPLIER).to(DEVICE).train()

# Scalar ("blue circle") getter/putter pair: concept size 1.
getter = LinearGetter(autoencoder, 1).to(DEVICE)
putter = LinearPutterWithComplement(autoencoder, 1, COMPLEMENT_SIZE).to(DEVICE)

# Categorical pairs for shape and color: concept size CONCEPT_SIZE.
shape_getter = LinearGetter(autoencoder, CONCEPT_SIZE).to(DEVICE)
shape_putter = LinearPutter(autoencoder, CONCEPT_SIZE, True).to(DEVICE)

color_getter = LinearGetter(autoencoder, CONCEPT_SIZE).to(DEVICE).train()
color_putter = LinearPutter(autoencoder, CONCEPT_SIZE, True).to(DEVICE)

models = torch.nn.ModuleList([autoencoder,
                              getter, putter,
                              shape_getter, shape_putter,
                              color_getter, color_putter])
# ``Module.parameters()`` yields every parameter exactly once, so tensors
# shared between models are not duplicated, and it does so in a stable order.
# An ordered list is used instead of a set because optimisers expect a
# collection whose ordering is the same on every run.
all_parameters = [parameter
                  for parameter in models.parameters()
                  if parameter.requires_grad]
optimiser = torch.optim.AdamW(all_parameters,
                              lr=LEARNING_RATE,
                              weight_decay=WEIGHT_DECAY)
print(f'Number of parameters = {sum(p.numel() for p in all_parameters)}')

# Frozen copies of the getters (eval mode, gradients disabled).  They are
# created after the optimiser so their parameters are never optimised.
shape_getter_clone = copy.deepcopy(shape_getter).eval().requires_grad_(False)
color_getter_clone = copy.deepcopy(color_getter).eval().requires_grad_(False)
getter_clone = copy.deepcopy(getter).eval().requires_grad_(False)

# Relative weight of every rule's loss in the total objective.  Keys have the
# form '<task prefix>.<rule name>' and match the flattened ``results`` built
# inside the loop.  The weights never change, so the mapping is built once
# here rather than on every step.
loss_weights = {
    'autoencoder.autoencode': 100,

    'shape.classify': 1,
    'shape.get_put': 100,
    'shape.put_get': 1,
    'shape.put_put': 1,
    'shape.undo': 10,

    'color.classify': 1,
    'color.get_put': 100,
    'color.put_get': 1,
    'color.put_put': 1,
    'color.undo': 10,

    'blue_circle.classify': 10,
    'blue_circle.get_put': 100,
    'blue_circle.put_get': 10,
    'blue_circle.put_put': 1,
    'blue_circle.undo': 10,
}

# Main optimisation loop: one iteration per batch, stopped after STEPS steps
# (the dataset itself never ends).
step = 0
with tqdm(total=STEPS) as progress_bar:
    for batch in dataloader:
        state_batch = batch['image'].to(DEVICE)

        # Column 0 of the attribute indices is the shape target, column 1 the
        # color target.
        shape_batch = batch['attributes'][:, 0].to(DEVICE)
        color_batch = batch['attributes'][:, 1].to(DEVICE)

        random_attribute = batch['random_attribute'].to(DEVICE)
        random_attribute_2 = batch['random_attribute_2'].to(DEVICE)

        # Scalar labels get a trailing dimension of size 1.
        value_batch = batch['label'].unsqueeze(1).to(DEVICE)
        random_value = batch['random_label'].unsqueeze(1).to(DEVICE)
        random_value_2 = batch['random_label_2'].unsqueeze(1).to(DEVICE)

        autoencoder_results = autoencoder_task(autoencoder=autoencoder,
                                               data=state_batch)

        # The same manipulator task is evaluated three times, once per
        # getter/putter pair.  Shape and color share the random attributes.
        shape_results = manipulator_task(putter=shape_putter,
                                         getter=shape_getter,
                                         getter_clone=shape_getter_clone,
                                         state=state_batch,
                                         value=shape_batch,
                                         random_value=random_attribute,
                                         random_value_2=random_attribute_2)
        color_results = manipulator_task(putter=color_putter,
                                         getter=color_getter,
                                         getter_clone=color_getter_clone,
                                         state=state_batch,
                                         value=color_batch,
                                         random_value=random_attribute,
                                         random_value_2=random_attribute_2)
        blue_circle_results = manipulator_task(putter=putter,
                                               getter=getter,
                                               getter_clone=getter_clone,
                                               state=state_batch,
                                               value=value_batch,
                                               random_value=random_value,
                                               random_value_2=random_value_2)
        # Flatten to '<task prefix>.<rule name>' -> rule result.
        results = {f'{name}.{k}': v
                   for name, result in (('autoencoder', autoencoder_results),
                                        ('shape', shape_results),
                                        ('color', color_results),
                                        ('blue_circle', blue_circle_results))
                   for k, v in result.items()}
        losses = {k: v['loss'] for k, v in results.items()}

        # Weighted sum over every entry of ``loss_weights``.
        loss = sum(losses[k] * v for k, v in loss_weights.items())

        optimiser.zero_grad()
        loss.backward()
        if CLIP_GRADIENTS:
            # Clamp every gradient entry to [-CLIP_GRADIENTS, CLIP_GRADIENTS].
            torch.nn.utils.clip_grad_value_(models.parameters(),
                                            CLIP_GRADIENTS)
        optimiser.step()

        progress_bar.update(1)
        step += 1

        # NOTE(review): presumably propagates the freshly updated autoencoder
        # into each putter -- confirm in ``models``.
        putter.sync_autoencoder(autoencoder)
        shape_putter.sync_autoencoder(autoencoder)
        color_putter.sync_autoencoder(autoencoder)

        # Refresh the frozen getter copies with the updated weights.
        getter_clone.load_state_dict(getter.state_dict())
        shape_getter_clone.load_state_dict(shape_getter.state_dict())
        color_getter_clone.load_state_dict(color_getter.state_dict())

        if step % LOGGING_STEP == 0:
            print(f'Ep {step}\nLoss: {loss:g} <-',
                  ' '.join(f'{k}={v:g}' for k, v in losses.items()))

        if step >= STEPS:
            break

# Draw one fixed test sprite and view it 11 times along a new leading
# dimension: one copy per value in each sweep below.
test_image = pil_to_tensor(
    spriteworld.draw(Shape.RECTANGLE, (40, 250, 200), 9, 12, 10, 20)
    .convert('RGB')
)
images = test_image.unsqueeze(0).expand(11, -1, -1, -1).to(DEVICE)

# Eleven 3-component rows moving in steps of 0.2 from (0, 0, 1) through
# (0, 1, 0) to (1, 0, 0).  The second ramp drops its first row because it
# repeats the last row of the first ramp.
v = torch.arange(0, 1.1, 0.2).unsqueeze(1).to(DEVICE)
zero_column = v * 0
first_ramp = torch.cat((zero_column, v, 1 - v), dim=-1)
second_ramp = torch.cat((v, 1 - v, zero_column), dim=-1)
attribute_values = torch.cat((first_ramp, second_ramp[1:]), dim=0)
# Eleven scalar values 0.0, 0.1, ..., 1.0 as a column.
blue_circle_values = torch.arange(0, 1.1, 0.1).unsqueeze(1).to(DEVICE)

with torch.no_grad():
    for eval_module in (autoencoder, putter, shape_putter, color_putter):
        eval_module.eval()

    # NOTE(review): presumably makes the putters accept the value vectors
    # directly instead of class indices -- confirm in ``models``.
    for attribute_putter in (shape_putter, color_putter):
        attribute_putter.one_hot_concept = False

    # ``reversed`` on a tensor flips it along dimension 0, so the color sweep
    # runs in the opposite direction to the shape sweep.
    shape_sweep = shape_putter(attribute_values, images)[0]
    color_sweep = color_putter(reversed(attribute_values), images)[0]
    blue_circle_sweep = putter(blue_circle_values, images)[0]
    outputs = [shape_sweep, color_sweep, blue_circle_sweep]

# Layout: a narrow left column for the original image (spanning the top two
# rows) and a wide right column with one strip of generated images per sweep.
figure = plt.figure(layout='tight', figsize=(16, 4))
gs = GridSpec(3, 2,
              figure=figure,
              width_ratios=(2, 10),
              wspace=0.01,
              hspace=0.3)

original_image_axis = figure.add_subplot(gs[:2, 0])
original_image_axis.axis('off')
# Channels-first tensor -> channels-last image for imshow.
original_image_axis.imshow(test_image.permute(1, 2, 0))
original_image_axis.text(15, 33, 'Original', ha='center', va='top')

axes = [figure.add_subplot(gs[strip, 1]) for strip in range(3)]
rows = [[image.cpu().permute(1, 2, 0) for image in row] for row in outputs]
row_labels = [('Triangle', 'Rectangle', 'Ellipse'),
              ('Red', 'Green', 'Blue'),
              ('0', '0.5', '1')]

# Every image occupies IMAGE_SIZE data units followed by a gap of ``margin``.
margin = 2
size = IMAGE_SIZE + margin
# Ticks sit under the centres of the first, middle and last image.
tick_positions = [IMAGE_SIZE // 2 + slot * size for slot in (0, 5, 10)]

for ax, row, xlabels in zip(axes, rows, row_labels):
    ax.get_yaxis().set_visible(False)
    ax.set_ylim(0, size)
    ax.set_xlim(0, 11*size - margin)
    ax.set_xticks(tick_positions, xlabels)
    for side in ('left', 'right', 'top'):
        ax.spines[side].set_visible(False)
    for i, image in enumerate(row):
        left = i*size
        ax.imshow(image, extent=(left, left + IMAGE_SIZE, margin, size))

OUTPUT_FILE = 'shape-output.png'
figure.savefig(OUTPUT_FILE, bbox_inches='tight', dpi=300)
print(f'Saved figure to {repr(OUTPUT_FILE)}.')
