import copy
from random import Random

import datasets
from pytorch_msssim import SSIM
import torch
from torch.utils.data import DataLoader
import torchvision.transforms.v2 as tv
from tqdm.auto import tqdm

from diagram import System
from models import AutoEncoder, LinearGetter, LinearPutter
from rules.autoencoder import autoencode
from rules.lens import (classify, get_put_teacher_forced, put_get, put_put,
                        undo_teacher_forced)

# CONSTANTS
# Image parameters
IMAGE_SIZE = 128  # images are resized to IMAGE_SIZE x IMAGE_SIZE by `augment`
IMAGE_CHANNELS = 3  # channel count handed to the SSIM loss
# NOTE: despite the section header, NUM_WORKERS is a data-loading setting: it
# is the number of DataLoader worker processes.
NUM_WORKERS = 4

# Architecture parameters
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Name of the dataset attribute column used as the `value` label.
CONCEPT = 'Smiling'
# Number of distinct concept values; also the exclusive upper bound used when
# sampling random concept values during training.
CONCEPT_SIZE = 2
# The four values below are passed positionally to AutoEncoder after
# IMAGE_SIZE; their exact meaning is defined by models.AutoEncoder.
LATENT_SIZE = 128
CNN_HIDDEN_LAYERS = 5
CNN_HIDDEN_CHANNELS = 8
CNN_CHANNEL_MULTIPLIER = 2

# Training hyper-parameters
BATCH_SIZE = 64
# Threshold given to clip_grad_value_; a falsy value (0/None) disables clipping.
CLIP_GRADIENTS = 1
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-2
STEPS = 100000  # total number of optimiser steps before training stops
LOGGING_STEP = 1000  # losses are printed every LOGGING_STEP steps
SEED = 0  # seeds both random.Random and torch.manual_seed

# Fetch the CelebA-with-attributes dataset through `datasets.load_dataset`.
raw_dataset = datasets.load_dataset('tpremoli/CelebA-attrs')

# Assign a dense index to every int64-typed column of the training split,
# following the order in which the columns appear in the split's schema.
features = {
    name: position
    for position, (name, _) in enumerate(
        (name, feature)
        for name, feature in raw_dataset['train'].features.items()
        if feature.dtype == 'int64'
    )
}

# Image pipeline applied to every example: random augmentation on the PIL
# image, resize to a square, convert to a tensor, then rescale to float32.
augment = tv.Compose([
    tv.TrivialAugmentWide(),
    tv.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    tv.PILToTensor(),
    tv.ToDtype(torch.float32, scale=True),
])


def transform(examples: dict) -> dict:
    """Turn a raw column-wise batch into image and label tensors.

    Args:
        examples: Mapping from column name to a list of per-example values.
            It must contain an ``'image'`` column; columns whose names are
            keys of the module-level ``features`` mapping are read as
            attribute flags. Any other column is ignored.

    Returns:
        A dict with ``'label'``, one row per example and one column per entry
        of ``features`` holding 1 where the raw flag equals 1 and 0
        otherwise, and ``'image'``, the stacked output of ``augment``.
    """
    # Start from an all-zero label matrix with one row per image.
    label_rows = [[0] * len(features) for _ in examples['image']]
    collected = {'label': label_rows}

    for column, values in examples.items():
        if column == 'image':
            collected['image'] = torch.stack(
                [augment(picture) for picture in values])
            continue
        if column not in features:
            continue
        slot = features[column]
        for row, flag in enumerate(values):
            # Any raw value other than 1 is stored as 0.
            label_rows[row][slot] = int(flag == 1)

    return {key: torch.as_tensor(value) for key, value in collected.items()}


# Shared SSIM module used by `image_loss`. data_range=1 matches the input
# images, which `augment` rescales to [0, 1]. size_average=True presumably
# reduces the per-image scores to one scalar -- confirm in pytorch_msssim docs.
ssim = SSIM(data_range=1, size_average=True, channel=IMAGE_CHANNELS)


def image_loss(pred, target):
    """Return the combined reconstruction loss between two image batches.

    The loss is ``MSE + 0.25 * L1 + (1 - SSIM)``, where MSE and L1 use the
    default mean reduction and SSIM comes from the module-level ``ssim``.

    Args:
        pred: Predicted image tensor.
        target: Reference image tensor, compared element-wise with ``pred``.

    Returns:
        The summed loss as a tensor.
    """
    squared_error = torch.nn.functional.mse_loss(pred, target)
    absolute_error = torch.nn.functional.l1_loss(pred, target)
    # SSIM is a similarity score, so subtract it from 1 to get a loss term.
    dissimilarity = 1 - ssim(pred, target)
    return squared_error + 0.25 * absolute_error + dissimilarity


manipulator_task = System()

# One entry per training rule: (rule name, rule function, loss function,
# wiring). The wiring keys are the rule's keyword names and the values are
# the names of the inputs supplied when `manipulator_task` is called.
# Registration order follows the order of this table.
_RULE_TABLE = (
    ('autoencode', autoencode, image_loss,
     {'autoencoder': 'autoencoder', 'data': 'state'}),
    ('classify', classify, torch.nn.CrossEntropyLoss(),
     {'getter': 'getter', 'state': 'state', 'value': 'value'}),
    ('get_put', get_put_teacher_forced, image_loss,
     {'putter': 'putter', 'state': 'state', 'value': 'value'}),
    # put_get reads the value back through the frozen getter clone.
    ('put_get', put_get, torch.nn.CrossEntropyLoss(),
     {'putter': 'putter', 'getter': 'getter_clone', 'state': 'state',
      'value': 'random_value'}),
    ('put_put', put_put, image_loss,
     {'putter': 'putter', 'state': 'state', 'value1': 'random_value',
      'value2': 'random_value_2'}),
    ('undo', undo_teacher_forced, image_loss,
     {'putter': 'putter', 'state': 'state', 'value': 'value',
      'random_value': 'random_value'}),
)

for _rule_name, _rule_fn, _rule_loss, _rule_wiring in _RULE_TABLE:
    manipulator_task.add_rule(_rule_name, _rule_fn, _rule_loss,
                              **_rule_wiring)

# Seed the stdlib generator and torch's global generator for repeatable runs.
# NOTE: `rng` is not referenced anywhere else in this script.
rng = Random(SEED)
torch.manual_seed(SEED)

# `transform` is applied lazily, each time examples are fetched.
train_dataset = raw_dataset['train'].with_transform(transform)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    # The training loop samples random values with a fixed BATCH_SIZE, so
    # every batch must be full.
    drop_last=True,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # DataLoader only accepts a prefetch_factor when worker processes are
    # used; fall back to None so NUM_WORKERS = 0 keeps working.
    prefetch_factor=1 if NUM_WORKERS > 0 else None,
)

autoencoder = AutoEncoder(IMAGE_SIZE,
                          LATENT_SIZE,
                          CNN_HIDDEN_LAYERS,
                          CNN_HIDDEN_CHANNELS,
                          CNN_CHANNEL_MULTIPLIER).to(DEVICE)
getter = LinearGetter(autoencoder, CONCEPT_SIZE).to(DEVICE)
putter = LinearPutter(autoencoder, CONCEPT_SIZE, True).to(DEVICE)
models = torch.nn.ModuleList([autoencoder, getter, putter])

# Collect trainable parameters as an ordered list. `Module.parameters()`
# yields each shared parameter once, so this still de-duplicates parameters
# reachable from more than one model, but unlike a set it has a stable order,
# which torch.optim requires of the collections it is given.
all_parameters = [parameter
                  for parameter in models.parameters()
                  if parameter.requires_grad]
optimiser = torch.optim.AdamW(all_parameters,
                              lr=LEARNING_RATE,
                              weight_decay=WEIGHT_DECAY)
print(f'Number of parameters = {sum(p.numel() for p in all_parameters)}')

# Frozen, eval-mode copy of the getter. It is created after the optimiser so
# its parameters are never optimised directly.
getter_clone = copy.deepcopy(getter).eval().requires_grad_(False)

# Relative weight of each rule's loss in the total objective. The weights are
# constant for the whole run, so they are built once instead of per step.
loss_weights = {
    'autoencode': 10,

    'classify': 1,
    'get_put': 1,
    'put_get': 1,
    'put_put': 1,
    'undo': 1,
}

# Train for exactly STEPS optimiser steps, cycling through the dataloader as
# many times as needed.
done = False
step = 0
with tqdm(total=STEPS) as progress_bar:
    while not done:
        for batch in train_dataloader:
            # Keep the putter's `decode` submodule in eval mode for the step.
            # NOTE(review): presumably because it mirrors the autoencoder's
            # decoder (see sync_autoencoder below) -- confirm in models.py.
            putter.decode.eval()

            state_batch = batch['image'].to(DEVICE)
            # Select the single label column for the configured concept.
            value_batch = batch['label'][..., features[CONCEPT]].to(DEVICE)

            # Two independent draws of random concept values in
            # [0, CONCEPT_SIZE), one per example (batches are always full).
            random_value_batch = torch.randint(high=CONCEPT_SIZE,
                                               size=(BATCH_SIZE,)).to(DEVICE)
            random_value_batch_2 = torch.randint(high=CONCEPT_SIZE,
                                                 size=(BATCH_SIZE,)).to(DEVICE)

            # Evaluate every registered rule; each result carries a 'loss'.
            results = manipulator_task(autoencoder=autoencoder,
                                       putter=putter,
                                       getter=getter,
                                       getter_clone=getter_clone,
                                       state=state_batch,
                                       value=value_batch,
                                       random_value=random_value_batch,
                                       random_value_2=random_value_batch_2)
            losses = {k: v['loss'] for k, v in results.items()}

            loss = sum(losses[k] * v for k, v in loss_weights.items())

            optimiser.zero_grad()
            loss.backward()
            if CLIP_GRADIENTS:
                torch.nn.utils.clip_grad_value_(models.parameters(),
                                                CLIP_GRADIENTS)
            optimiser.step()

            progress_bar.update(1)
            step += 1

            # Project-defined hook, called after every optimiser step with
            # the freshly updated autoencoder.
            putter.sync_autoencoder(autoencoder)

            # Refresh the frozen clone so it tracks the trained getter.
            getter_clone.load_state_dict(getter.state_dict())

            if step % LOGGING_STEP == 0:
                # NOTE: the 'Ep' label shows the step count, not an epoch.
                print(f'Ep {step}\nLoss: {loss:g} <-',
                      ' '.join(f'{k}={v:g}' for k, v in losses.items()))

            if step >= STEPS:
                done = True
                break

# Write the three trained models' state dicts to one checkpoint file.
OUTPUT_FILE = 'models.pt'
checkpoint = {
    'autoencoder': autoencoder.state_dict(),
    'getter': getter.state_dict(),
    'putter': putter.state_dict(),
}
torch.save(checkpoint, OUTPUT_FILE)
print(f'Saved models to {repr(OUTPUT_FILE)}.')
