#!/usr/bin/env python

import argparse
import os
from functools import partial, reduce

# Root directory for this runner's artifacts: train() creates one
# sub-directory per run named after the current UTC Unix timestamp, plus
# `logs` (TensorBoard event files) and `test_samples` (images dumped while
# testing). infer() looks for its snapshot under this directory as well.
output_home = os.path.join('output', 'shape')

#
# Simple parser to get custom train/valid/test splits.
#
class SplitsParser(argparse.Action):
    """argparse action turning ``--splits`` into a (train, valid, test) tuple.

    The option value must be three comma-separated, non-negative integers
    that add up to 100, e.g. ``--splits 85,5,10``. The parsed tuple of ints
    is stored on the namespace under ``self.dest``. Invalid input is reported
    through ``parser.error``, which prints the usage line and exits with
    status 2.
    """

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        # The whole comma-separated list arrives as one string token, so a
        # custom nargs would break the parsing in __call__.
        if nargs is not None:
            raise ValueError("nargs not allowed")
        super().__init__(option_strings, dest, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O`` and the validation would silently disappear.
        try:
            _values = tuple(int(s.strip()) for s in values.split(','))
        except ValueError:
            _values = None
        if _values is None or len(_values) != 3:
            parser.error("Invalid splits format: Need a list of 3 integers separated with commas")
        # Negative shares would make torch's random_split fail later on.
        if any(v < 0 for v in _values):
            parser.error("The splits must be non-negative")
        if sum(_values) != 100:
            parser.error("The splits must add up to 100%")
        setattr(namespace, self.dest, _values)


from datetime import datetime, timezone
from hashlib import sha1
import logging
from pathlib import Path

import av
import numpy
from PIL import Image
import torch
import yaml

from classifier import ShapeClassifier

# Make logging.info() records (torch seed, snapshot directory, early-stopping
# notice) visible. basicConfig attaches a stderr handler to the root logger
# and is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)

#
# Helpers
#
from dataclasses import dataclass
@dataclass
class RuntimeConfig:
    """Per-run state passed to ``apply`` for every batch.

    ``apply`` uses ``model`` and ``loss_fn``. ``mode``, ``epoch`` and
    ``batch`` are bookkeeping fields that ``train`` updates as it moves
    through its 'training', 'validation' and 'testing' loops; ``apply``
    does not read them.
    """
    # Current phase label; train() starts with 'init'.
    mode: str
    # Network evaluated on each batch.
    model: torch.nn.Module
    # Criterion called as loss_fn(outputs, expected).
    loss_fn: torch.nn.Module
    # Epoch currently being processed (0 until train() sets it).
    epoch: int = 0
    # Index of the batch within the current loop (0 until train() sets it).
    batch: int = 0



def last_snapshot_for(base=None):
    """Return the path of the most recent ``.pth`` snapshot under *base*.

    train() writes each run into ``<base>/<unix timestamp>/``. Sibling
    directories such as ``logs`` and ``test_samples`` hold other artifacts
    and are ignored. Run directories (all-digit names) are scanned newest
    first, comparing the timestamps numerically. The lexicographically
    greatest ``.pth`` file of the first run that has any is returned. With
    train()'s naming scheme (``shape_classifier_<epoch:03>_<batch:06>.pth``
    and ``shape_classifier_<epoch:03>_complete.pth``), that is the latest
    snapshot of that run.

    :param base: output root to search; ``None`` (the default) means
        ``output_home``, resolved at call time.
    :returns: the snapshot path, or ``None`` when *base* does not exist or
        no run directory contains a ``.pth`` file.
    """
    if base is None:
        base = output_home
    if not os.path.isdir(base):
        return None

    # Only timestamp-named directories are training runs.
    run_dirs = [
        name for name in os.listdir(base)
        if name.isdecimal() and os.path.isdir(os.path.join(base, name))
    ]
    # Numeric sort: '99' must come before '100'.
    for run in sorted(run_dirs, key=int, reverse=True):
        run_dir = os.path.join(base, run)
        snapshots = sorted(
            name for name in os.listdir(run_dir)
            if name.endswith('.pth') and os.path.isfile(os.path.join(run_dir, name))
        )
        if snapshots:
            return os.path.join(run_dir, snapshots[-1])
    return None


def config_hash(config: dict) -> str:
    """Return the SHA-1 hex digest of ``str(config)``.

    The text is encoded as UTF-8 rather than ASCII, so configs containing
    non-ASCII strings (paths, labels) hash instead of raising
    UnicodeEncodeError. ASCII-only configs produce exactly the same digest
    as before, because UTF-8 encodes ASCII identically.

    Note: the digest depends on ``str(config)``, so dicts holding the same
    items in a different insertion order hash differently.
    """
    return sha1(str(config).encode('utf-8')).hexdigest()


def apply(rc, optimizer, expected, inputs, test=False):
    """Run one batch through ``rc.model`` and score it with ``rc.loss_fn``.

    When *optimizer* is given, gradients are zeroed, the loss is
    back-propagated and one optimizer step is taken. Pass ``None`` for
    evaluation; callers wrap those calls in ``torch.no_grad()``.

    :param rc: object exposing ``model`` and ``loss_fn`` (a RuntimeConfig).
    :param optimizer: torch optimizer, or ``None`` to skip the update.
    :param expected: target class indices for the batch.
    :param inputs: model input batch.
    :param test: when true, also count correct predictions.
    :returns: the batch loss as a tensor detached from the autograd graph;
        or ``(loss, matches)`` when *test* is true, where ``matches`` is the
        number of samples whose arg-max prediction equals the target.
    """
    if optimizer is not None:
        optimizer.zero_grad()
    outputs = rc.model(inputs)

    loss = rc.loss_fn(outputs, expected)
    if optimizer is not None:
        # The graph is used for one backward pass only; retaining it would
        # just keep activations in memory.
        loss.backward()
        optimizer.step()

    # Detach so callers that accumulate the loss across batches (running
    # averages) do not keep every batch's autograd graph alive.
    loss = loss.detach()

    if not test:
        return loss

    # One vectorised comparison (a single device sync) instead of two .item()
    # calls per sample. Each sample's outputs are flattened before argmax,
    # matching a per-sample torch.argmax; assumes outputs are batched (B, ...).
    predictions = torch.argmax(outputs.flatten(start_dim=1), dim=1)
    matches = int((predictions == expected.reshape(predictions.shape)).sum().item())
    return (loss, matches)


def train(config):
    """Train a ShapeClassifier on the dataset named by *config*, then test it.

    *config* is the argparse namespace built in ``__main__``. The fields read
    are ``seed``, ``device``, ``dataset``, ``splits`` (train/valid/test
    percentages) and ``epochs``.

    Each run writes to ``<output_home>/<UTC unix timestamp>/``. It contains
    ``seed.txt``, a snapshot every ``save_every`` training batches and one
    snapshot at the end of every epoch. TensorBoard scalars go to
    ``<output_home>/logs/<timestamp>``.

    Training stops early once an epoch's validation loss is no lower than
    each of the three preceding epochs' losses. The model as left by the
    last epoch run is then evaluated on the test split, and its accuracy is
    printed.
    """
    # Read the namespace we were given rather than the module-global `args`.
    if config.seed is not None:
        torch.manual_seed(config.seed)
    logging.info(f'Torch seed: {torch.initial_seed()}')

    common_dtype = torch.float32

    device = config.device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data loading and train/validation/test split.
    from loaders.shapes import ShapesDataset
    dataset = ShapesDataset(config.dataset, return_stills=True)

    tn = int(len(dataset) * config.splits[0] / 100)
    vn = int(len(dataset) * config.splits[1] / 100)
    train_ds, valid_ds, test_ds = torch.utils.data.random_split(
        dataset, (tn, vn, len(dataset) - tn - vn))

    from torch.utils.data import DataLoader
    # Align batch size with clip size (default 30), as the model works on stills. TODO parameterize.
    batch_size = 30
    dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False)
    validation_dataloader = DataLoader(valid_ds, batch_size=batch_size, num_workers=0, drop_last=False)
    testing_dataloader = DataLoader(test_ds, batch_size=batch_size, num_workers=0, drop_last=False)

    # Target model, optimizer and training criterion.
    model = ShapeClassifier(dataset.class_count())
    model.to(device)
    optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=0.001, betas=(0.9, 0.999),
            eps=1e-08, weight_decay=0.01,
            amsgrad=False)
    criterion = torch.nn.CrossEntropyLoss()

    # Each run gets its own directory named after the current UTC Unix time.
    save_dir = os.path.join(output_home, str(int(datetime.now(timezone.utc).timestamp())))
    logging.info(f"Training from scratch, snapshotting to {save_dir}")
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    with open(os.path.join(save_dir, 'seed.txt'), 'w') as f:
        f.write(str(torch.initial_seed()))

    save_every = 1000  # batches
    logging.info(f'Snapshots saved every {save_every} batches')

    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir=os.path.join(output_home, 'logs', os.path.basename(save_dir)))

    rc = RuntimeConfig(mode='init', model=model, loss_fn=criterion)

    initial_epoch = 1
    last_epoch = initial_epoch + config.epochs - 1
    validation_history = []
    try:
        for epoch in range(initial_epoch, last_epoch + 1):
            _train_epoch(rc, optimizer, dataloader, device, common_dtype, writer,
                         epoch, last_epoch, save_dir, save_every)
            _save_snapshot(os.path.join(save_dir, f'shape_classifier_{epoch:03}_complete.pth'),
                           epoch, model, optimizer)

            # With an empty validation split the loss would stay at 0.0 and
            # trigger early stopping spuriously; skip validation entirely.
            if len(validation_dataloader) == 0:
                continue

            validation_loss = _run_validation(rc, validation_dataloader, device,
                                              common_dtype, writer, epoch)
            # Stop when this epoch is no better than each of the three
            # preceding ones. The window must exclude the current value:
            # comparing a loss with itself is always true.
            plateaued = (len(validation_history) >= 3
                         and all(previous <= validation_loss
                                 for previous in validation_history[-3:]))
            validation_history.append(validation_loss)
            if plateaued:
                logging.info(f"Early stopping after validation loss not decreasing for 3 epochs.")
                break

        _run_test_split(rc, testing_dataloader, device, common_dtype, writer)
    finally:
        # Flush pending TensorBoard events even if a phase fails part-way.
        writer.close()


def _save_snapshot(path, epoch, model, optimizer):
    """Save the epoch number plus model and optimizer state dicts to *path*."""
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, path)


def _train_epoch(rc, optimizer, loader, device, dtype, writer, epoch, last_epoch,
                 save_dir, save_every):
    """Run one optimisation pass over *loader*, logging the running mean loss
    and writing a snapshot every *save_every* batches."""
    from tqdm import tqdm
    rc.mode = 'training'
    rc.epoch = epoch
    rc.model.train()
    running_loss = 0.0
    with tqdm(total=len(loader)) as pbar:
        pbar.set_description(f'Epoch {epoch}/{last_epoch}')
        for idx, batch in enumerate(loader):
            rc.batch = idx

            videos = batch['image'].type(dtype).to(device)
            expected = batch['label'].to(device)

            # .item() gives a plain float, so the running mean holds no
            # reference to any batch's autograd graph.
            batch_loss = apply(rc, optimizer, expected, videos).item()
            running_loss = (batch_loss + running_loss * idx) / (idx + 1)

            if idx % save_every == save_every - 1:
                _save_snapshot(
                    os.path.join(save_dir, f'shape_classifier_{epoch:03}_{idx + 1:06}.pth'),
                    epoch, rc.model, optimizer)

            pbar.set_postfix(loss=running_loss)
            pbar.update(1)

            # Log every 100 batches and on the last batch of this epoch.
            if idx % 100 == 99 or idx == len(loader) - 1:
                writer.add_scalar('loss/train', running_loss, len(loader) * (epoch - 1) + idx)


def _run_validation(rc, loader, device, dtype, writer, epoch):
    """Return the mean per-batch loss of ``rc.model`` over *loader*, logging it
    to *writer* every 100 batches and on the last batch."""
    from tqdm import tqdm
    rc.mode = 'validation'
    rc.model.eval()
    validation_loss = 0.0
    with torch.no_grad(), tqdm(total=len(loader)) as pbar:
        pbar.set_description(f'Validation on epoch {epoch}')
        for idx, batch in enumerate(loader):
            rc.batch = idx

            videos = batch['image'].type(dtype).to(device)
            expected = batch['label'].to(device)
            batch_loss = apply(rc, None, expected, videos).item()

            validation_loss = (batch_loss + validation_loss * idx) / (idx + 1)

            pbar.set_postfix(loss=validation_loss)
            pbar.update(1)

            if idx % 100 == 99 or idx == len(loader) - 1:
                writer.add_scalar('loss/valid', validation_loss, len(loader) * (epoch - 1) + idx)
    return validation_loss


def _run_test_split(rc, loader, device, dtype, writer):
    """Evaluate ``rc.model`` on the test split: log loss/accuracy to *writer*,
    dump periodic input batches as JPEGs under ``<output_home>/test_samples``
    and print the final accuracy."""
    from tqdm import tqdm
    rc.mode = 'testing'
    rc.model.eval()
    testing_loss = 0.0
    accuracy_results = 0
    accuracy_frames = 0
    accuracy = 0.0
    artifacts_path = os.path.join(output_home, 'test_samples')
    Path(artifacts_path).mkdir(parents=True, exist_ok=True)
    with torch.no_grad(), tqdm(total=len(loader)) as pbar:
        # NOTE(review): this is the model as left by the last epoch run, not
        # necessarily the epoch with the lowest validation loss.
        pbar.set_description(f'Testing best model')
        for idx, batch in enumerate(loader):
            rc.batch = idx

            videos = batch['image'].type(dtype).to(device)
            expected = batch['label'].to(device)

            batch_loss, matches = apply(rc, None, expected, videos, test=True)
            testing_loss = (batch_loss.item() + testing_loss * idx) / (idx + 1)

            accuracy_frames += len(videos)
            accuracy_results += matches

            pbar.set_postfix(loss=testing_loss)
            pbar.update(1)

            if idx % 10 == 9 or idx == len(loader) - 1:
                writer.add_scalar('loss/test', testing_loss, idx)
                # Move to host memory first: .numpy() refuses CUDA tensors.
                # Assumes frames are (C, H, W) with values in 0..255 — TODO
                # confirm against ShapesDataset.
                for fidx, frame in enumerate(videos.cpu().numpy()):
                    Image.fromarray(frame.astype(numpy.uint8).transpose((1, 2, 0)), mode='RGB').save(
                        os.path.join(artifacts_path, f"{idx}_{fidx}.jpg"))

            accuracy = (accuracy_results / accuracy_frames) * 100.0
            writer.add_scalar('acc/plain', accuracy, idx)
            writer.add_scalar('acc/batches', accuracy_frames, idx)
            writer.add_scalar('acc/matches', accuracy_results, idx)

    if accuracy_frames == 0:
        logging.warning('Test split is empty; no accuracy to report.')
    else:
        print(f'Final accuracy on test set: {accuracy}')

def infer(dataset, image):
    """Classify one image with the newest snapshot found under ``output_home``.

    :param dataset: path of the dataset archive. It is loaded only for its
        class count and its class-index → label mapping.
    :param image: path of the image file to classify.
    :returns: ``ShapesDataset.label_for`` of the arg-max class.
    :raises FileNotFoundError: when no ``.pth`` snapshot can be found.
    """
    common_dtype = torch.float32
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    from loaders.shapes import ShapesDataset
    shapes = ShapesDataset(dataset, return_stills=True)

    # Fail with a clear message instead of torch.load(None).
    model_file = last_snapshot_for(output_home)
    if model_file is None:
        raise FileNotFoundError(
            f"No .pth snapshot found under {output_home}; run the train command first.")
    print(f"Shape classifier will load from {model_file}")

    with torch.no_grad():
        model = ShapeClassifier(shapes.class_count())
        state = torch.load(model_file, map_location=device)
        model.load_state_dict(state['model'])
        model.eval()
        model.to(device)

        # convert('RGB') drops alpha and expands greyscale/palette images so
        # the array is always H x W x 3. Presumably the model expects 3
        # channels: training samples are exported as RGB. The context
        # manager closes the file handle.
        with Image.open(image) as img:
            pixels = numpy.array(img.convert('RGB')).transpose((2, 0, 1))
        # HWC -> CHW above, then add a batch dimension of 1.
        batch = torch.as_tensor(pixels, dtype=common_dtype).unsqueeze(0)

        predictions = model(batch.to(device))
        class_idx = torch.argmax(predictions).item()

    return shapes.label_for(class_idx)

if __name__ == '__main__':
    parser = argparse.ArgumentParser("Shape runner")
    # Tried subparsers, but did not succeed with them (yet).
    parser.add_argument("command",
            choices=["train", "infer"],
            help="Choose a run mode.")
    parser.add_argument("--dataset",
            required=False,
            default='./shape_dataset.pickle',
            help="Dataset archive, as generated by evaluation/bin/gen_shape_dataset.")
    parser.add_argument("--image",
            required=False,
            default=None,
            help="Image to run inference on. Ignored if command is train, required if command is infer.")
    parser.add_argument("--epochs",
            required=False,
            type=int,
            default=10,
            help="Number of epochs to train for.")
    parser.add_argument("--splits",
            required=False,
            action=SplitsParser,
            default=(85,5,10),
            help="Data point splits training/validation/testing. Data points are slices of AV files, not files themselves. Defaults to (85, 5, 10)%% for each. Numbers must add to 100.")
    # Only torch is seeded by train(); numpy's global RNG is left untouched.
    parser.add_argument("--seed",
            required=False,
            type=int,
            default=None,
            help="Seed for PyTorch. It will affect the data split allocations.")
    parser.add_argument("--device",
            choices=["cpu", "cuda"],
            required=False,
            default=None,
            help="Force a device, either cpu or CUDA.")
    args = parser.parse_args()

    if args.command == 'train':
        train(args)
    elif args.command == 'infer':
        if not args.image:
            # Report through argparse (usage line, exit status 2) rather
            # than a bare traceback.
            parser.error('Missing image file to run inference on.')
        print(infer(args.dataset, args.image))
