#!/usr/bin/env python

import argparse
import os
from functools import partial, reduce

# Root directory for everything this script writes: per-run snapshot folders,
# TensorBoard logs and test sample images all live beneath it.
output_home = os.path.join('output', 'ec')

#
# Simple parser to get custom train/valid/test splits.
#
class SplitsParser(argparse.Action):
    """argparse action turning ``"train,valid,test"`` into a tuple of 3 ints.

    The option value must be three comma-separated integer percentages
    (whitespace around each number is ignored) that are non-negative and
    add up to 100. Any violation is reported through ``parser.error``,
    which prints the usage message and exits with status 2.
    """

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        """Register the action; a single string value is expected.

        Raises:
            ValueError: if ``nargs`` is supplied.
        """
        if nargs is not None:
            raise ValueError("nargs not allowed")
        super().__init__(option_strings, dest, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        """Validate ``values`` and store the parsed tuple on ``namespace``."""
        try:
            _values = tuple(int(s.strip()) for s in values.split(','))
        except (AttributeError, ValueError):
            # Not a string, or one of the items is not an integer: fall
            # through to the length check so there is one error path.
            _values = ()

        # Explicit checks rather than ``assert``: assertions are stripped
        # when Python runs with -O, which would silently disable validation.
        if len(_values) != 3:
            parser.error("Invalid splits format: Need a list of 3 integers separated with commas")
        if any(v < 0 for v in _values):
            parser.error("The splits must not be negative")
        if sum(_values) != 100:
            parser.error("The splits must add up to 100%")
        setattr(namespace, self.dest, _values)


from datetime import datetime, timezone
from hashlib import sha1
import logging
from pathlib import Path

import av
import numpy
from PIL import Image
import torch
import yaml

from classifier import EmotionClassifier

# Configure the root logger at INFO level so the logging.info() progress
# messages issued by this script are emitted.
logging.basicConfig(level=logging.INFO)

#
# Helpers
#
from dataclasses import dataclass
@dataclass
class RuntimeConfig:
    """Mutable bundle of run state passed to ``apply``.

    ``train`` creates one instance and updates ``mode``, ``epoch`` and
    ``batch`` as it progresses; ``apply`` reads ``model`` and ``loss_fn``.
    """
    # Phase label; train() sets 'init', 'training', 'validation', 'testing'.
    mode: str
    # Network that apply() calls on the scaled inputs.
    model: torch.nn.Module
    # Criterion that apply() calls as loss_fn(outputs, expected).
    loss_fn: torch.nn.Module
    # Epoch currently being processed (train() counts from 1).
    epoch: int = 0
    # Index of the current batch within the loader being iterated.
    batch: int = 0



def last_snapshot_for(base=output_home):
    """Return the path of the most recent ``.pth`` snapshot under *base*.

    The immediate sub-directories of *base* are examined from the
    lexicographically greatest name to the smallest (run directories are
    named after their creation timestamp, so the newest run is tried
    first). The first directory that contains at least one ``.pth`` file
    wins, and within it the file whose name sorts last is returned.
    Directories without snapshots are skipped. Files sitting directly in
    *base* are ignored.

    Args:
        base: directory holding one sub-directory per training run.

    Returns:
        The snapshot path, or ``None`` when *base* is missing, has no
        sub-directories, or none of them holds a ``.pth`` file.
    """
    try:
        with os.scandir(base) as entries:
            run_dirs = sorted(
                (entry.name for entry in entries if entry.is_dir()),
                reverse=True)
    except OSError:
        # Missing or unreadable base directory: nothing to load.
        return None

    for run_dir in run_dirs:
        run_path = os.path.join(base, run_dir)
        try:
            snapshots = sorted(f for f in os.listdir(run_path) if f.endswith('.pth'))
        except OSError:
            continue
        if snapshots:
            return os.path.join(run_path, snapshots[-1])
    return None


def config_hash(config:dict) -> str:
    """Return the SHA-1 hex digest of ``str(config)``.

    The text is encoded as UTF-8 so that configurations holding non-ASCII
    values can be hashed too; for pure-ASCII content the digest is the
    same as with an ASCII encoding. The digest depends on the dict's
    ``str()`` form and therefore on key insertion order.
    """
    return sha1(str(config).encode('utf-8')).hexdigest()


def apply(rc, optimizer, expected, inputs, test=False):
    """Run one batch through ``rc.model`` and optionally take an optimizer step.

    Args:
        rc: object exposing ``model`` and ``loss_fn`` (see RuntimeConfig).
        optimizer: optimizer to step after back-propagation, or a falsy
            value (e.g. ``None``) to only evaluate the loss.
        expected: target tensor handed to ``rc.loss_fn``; when ``test`` is
            true it must hold one class index per sample.
        inputs: input tensor; it is divided by 255 before reaching the model.
        test: when true, also count correct top-1 predictions.

    Returns:
        The loss as a Python float, or ``(loss, matches)`` when ``test``
        is true, ``matches`` being the number of samples whose arg-max
        output equals the expected class.
    """
    if optimizer:
        optimizer.zero_grad()
    outputs = rc.model(inputs / 255.0) # Normalize to [0, 1] (assumes 0-255 pixel values)

    loss = rc.loss_fn(outputs, expected)
    if optimizer:
        # NOTE(review): retain_graph=True is kept from the original code; it
        # is only needed if the graph is back-propagated through again.
        loss.backward(retain_graph=True)
        optimizer.step()

    if not test:
        return loss.item()

    # Compare the whole batch at once instead of calling .item() per sample,
    # which would force one host/device round-trip per element.
    predicted = outputs.flatten(start_dim=1).argmax(dim=1)
    matches = int((predicted == expected.reshape(-1)).sum().item())
    return (loss.item(), matches)


def _save_snapshot(model, optimizer, epoch, path):
    """Write the epoch number plus the model and optimizer state dicts to *path*."""
    state = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            }
    torch.save(state, path)


def train(config):
    """Train an EmotionClassifier, validate after each epoch, then test it.

    Everything is read from *config* (the parsed command line when run as
    a script), which must provide:

    * ``seed``: int or ``None``; seeds torch when set.
    * ``device``: device name, or a falsy value to auto-select cuda/cpu.
    * ``dataset``: location passed to ``RAVDESSStillsDataset``.
    * ``splits``: (train, valid, test) percentages.
    * ``batch_size``: batch size for all three data loaders.
    * ``epochs``: maximum number of epochs.

    Side effects: creates ``<output_home>/<utc timestamp>/`` holding
    ``seed.txt`` and the ``.pth`` snapshots, writes TensorBoard scalars
    under ``<output_home>/logs/<timestamp>`` and dumps sampled test frames
    as JPEGs in ``<output_home>/test_samples``. Training stops early when
    the validation loss is no better than any of the three previous epochs.
    """
    if config.seed is not None:
        torch.manual_seed(config.seed)
    logging.info(f'Torch seed: {torch.initial_seed()}')

    common_dtype = torch.float32

    device = config.device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data loading
    from loaders.ravdess_stills import RAVDESSStillsDataset
    dataset = RAVDESSStillsDataset(config.dataset)

    # The test split takes whatever is left so the three sizes always sum
    # to len(dataset) despite the integer truncation above it.
    tn = int(len(dataset) * config.splits[0] / 100)
    vn = int(len(dataset) * config.splits[1] / 100)
    train_ds, valid_ds, test_ds = torch.utils.data.random_split(dataset, (tn, vn, len(dataset) - tn - vn))

    from torch.utils.data import DataLoader
    batch_size = config.batch_size
    dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False)
    validation_dataloader = DataLoader(valid_ds, batch_size=batch_size, num_workers=0, drop_last=False)
    testing_dataloader = DataLoader(test_ds, batch_size=batch_size, num_workers=0, drop_last=False)

    # Target model
    model = EmotionClassifier(dataset.class_count())
    model.to(device)

    # Optimizer
    optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=0.001, betas=(0.9, 0.999),
            eps=1e-08, weight_decay=0.01,
            amsgrad=False)

    # Model and optimizer initialization
    save_dir = os.path.join(output_home, str(int(datetime.now(timezone.utc).timestamp())))
    initial_epoch = 1
    logging.info(f"Training from scratch, snapshotting to {save_dir}")

    Path(save_dir).mkdir(parents=True, exist_ok=True)

    with open(os.path.join(save_dir, 'seed.txt'), 'w') as f:
        f.write(str(torch.initial_seed()))

    # Training metric
    criterion = torch.nn.CrossEntropyLoss()

    from tqdm import tqdm
    save_every = 1000 # batches
    logging.info(f'Snapshots saved every {save_every} batches')

    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir=os.path.join(output_home, 'logs', os.path.basename(save_dir)))

    rc = RuntimeConfig(
        mode='init',
        model=model,
        loss_fn=criterion,
    )

    validation_history = []

    last_epoch = initial_epoch + config.epochs - 1
    for epoch in range(initial_epoch, last_epoch+1):
        running_loss = 0.0
        rc.mode = 'training'
        rc.epoch = epoch
        # Validation below switches to eval mode, so re-enable training
        # behaviour (dropout, batch-norm statistics) at every epoch.
        model.train()
        with tqdm(total=len(dataloader)) as pbar:
            pbar.set_description(f'Epoch {epoch}/{last_epoch}')
            for idx, batch in enumerate(dataloader):
                rc.batch = idx

                stills = batch['still'].type(common_dtype).to(device)
                expected = batch['label'].to(device)

                _loss = apply(rc, optimizer, expected, stills)

                # Incremental mean of the per-batch losses seen so far.
                running_loss = (_loss + running_loss * idx) / (idx + 1)

                if idx % save_every == (save_every-1):
                    batch_block = str(idx+1).zfill(6)
                    epoch_block = str(epoch).zfill(3)
                    _save_snapshot(model, optimizer, epoch,
                            os.path.join(save_dir, f'ec_{epoch_block}_{batch_block}.pth'))

                pbar.set_postfix(loss=running_loss)
                pbar.update(1)

                # Log every 100 batches and on the last *training* batch.
                if idx % 100 == 99 or idx == len(dataloader) - 1:
                    writer.add_scalar('loss/train', running_loss, len(dataloader)*(epoch-1) + idx)

        epoch_block = str(epoch).zfill(3)
        _save_snapshot(model, optimizer, epoch,
                os.path.join(save_dir, f'ec_{epoch_block}_complete.pth'))

        model.eval()
        with torch.no_grad():
            rc.mode = 'validation'
            validation_loss = 0.0
            with tqdm(total=len(validation_dataloader)) as pbar:
                pbar.set_description(f'Validation on epoch {epoch}')
                for idx, batch in enumerate(validation_dataloader):
                    rc.batch = idx

                    stills = batch['still'].type(common_dtype).to(device)
                    expected = batch['label'].to(device)
                    _loss = apply(rc, None, expected, stills)

                    validation_loss = (_loss + validation_loss * idx) / (idx + 1)

                    pbar.set_postfix(vloss=validation_loss)
                    pbar.update(1)

                    if idx % 100 == 99 or idx == len(validation_dataloader) - 1:
                        writer.add_scalar('loss/valid', validation_loss, len(validation_dataloader)*(epoch-1) + idx)

        validation_history.append(validation_loss)
        # Compare against the three epochs *before* this one; including the
        # current loss in the window would compare it with itself. Skip the
        # check when there is no validation data (the loss stays at 0.0).
        previous = validation_history[-4:-1]
        if (len(validation_dataloader) > 0 and len(previous) == 3
                and all(x <= validation_loss for x in previous)):
            logging.info("Early stopping after validation loss not decreasing for 3 epochs.")
            break

    model.eval()
    with torch.no_grad():
        rc.mode = 'testing'
        testing_loss = 0.0
        accuracy_results = 0
        accuracy_frames = 0
        _accuracy = 0.0 # stays defined even when the test split is empty
        artifacts_path = os.path.join(output_home, 'test_samples')
        Path(artifacts_path).mkdir(parents=True, exist_ok=True)
        with tqdm(total=len(testing_dataloader)) as pbar:
            # NOTE: the model tested is the one left by the last epoch run;
            # no best-checkpoint selection happens here.
            pbar.set_description('Testing best model')
            for idx, batch in enumerate(testing_dataloader):
                rc.batch = idx

                stills = batch['still'].type(common_dtype).to(device)
                expected = batch['label'].to(device)

                _loss, matches = apply(rc, None, expected, stills, test=True)

                testing_loss = (_loss + testing_loss * idx) / (idx + 1)

                accuracy_frames += len(stills)
                accuracy_results += matches

                pbar.set_postfix(tloss=testing_loss)
                pbar.update(1)

                # Every 10th batch (and the last one): log the loss and dump
                # the batch's frames to disk for visual inspection.
                if idx % 10 == 9 or idx == len(testing_dataloader) - 1:
                    writer.add_scalar('loss/test', testing_loss, idx)
                    for fidx in range(len(stills)):
                        # Channel-first tensor -> height x width x channel bytes.
                        frame = stills[fidx].cpu().numpy().astype(numpy.uint8).transpose((1,2,0))
                        Image.fromarray(frame, mode='RGB').save(os.path.join(artifacts_path, f"{idx}_{fidx}.jpg"))

                _accuracy = (accuracy_results / accuracy_frames) * 100.0
                writer.add_scalar('acc/plain', _accuracy, idx)
                writer.add_scalar('acc/batches', accuracy_frames, idx)
                writer.add_scalar('acc/matches', accuracy_results, idx)
            print(f'Final accuracy on test set: {_accuracy}')

    # Flush pending events to disk.
    writer.close()

def infer(dataset, image):
    """Classify a single image with the most recent snapshot.

    Args:
        dataset: location passed to ``RAVDESSStillsDataset``; the dataset
            supplies the class count and the index-to-label mapping.
        image: path (or file object) of the image to classify.

    Returns:
        Whatever ``label_for`` of the dataset returns for the class with
        the highest model output.

    Raises:
        FileNotFoundError: when no ``.pth`` snapshot exists under
            ``output_home``.
    """
    common_dtype = torch.float32
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    from loaders.ravdess_stills import RAVDESSStillsDataset
    stills_dataset = RAVDESSStillsDataset(dataset)

    model_file = last_snapshot_for(output_home)
    if model_file is None:
        raise FileNotFoundError(
            f"No snapshot (.pth) found under {output_home}; train a model first.")
    print(f"Emotion classifier will load from {model_file}")

    with torch.no_grad():
        model = EmotionClassifier(stills_dataset.class_count())

        # NOTE(security): torch.load unpickles the file; only load snapshots
        # produced by a trusted training run.
        state = torch.load(model_file, map_location=device)
        model.load_state_dict(state['model'])
        model.eval()
        model.to(device)

        # Force three channels so grayscale, palette or RGBA files still
        # give a (3, H, W) array, and close the file once it is decoded.
        with Image.open(image) as picture:
            pixels = numpy.array(picture.convert('RGB'))
        img = pixels.transpose((2,0,1)) / 255.0
        batch = torch.as_tensor(img, dtype=common_dtype).unsqueeze(0)

        predictions = model(batch.to(device))
        class_idx = torch.argmax(predictions).item()

    return stills_dataset.label_for(class_idx)

if __name__ == '__main__':
    # Command-line entry point: parse the arguments, then dispatch to
    # train() or infer().
    parser = argparse.ArgumentParser("EC runner")
    # Subparsers would fit the train/infer split better, but attempts to
    # use them have not succeeded yet.
    parser.add_argument("command",
            choices=["train", "infer"],
            help="Choose a run mode.")
    parser.add_argument("dataset",
            help="Path to a folder containing either the RAVDESS ZIP archives, or the binary produced by a Coach's RAVDESS handler (ravdess_stills_dataset_precompute.bin).")
    parser.add_argument("--image",
            required=False,
            default=None,
            help="Image to run inference on. Ignored if command is train, required if command is infer.")
    parser.add_argument("--batch-size",
            dest='batch_size',
            required=False,
            type=int,
            default=64,
            help="Training batch size.")
    parser.add_argument("--epochs",
            required=False,
            type=int,
            default=10,
            help="Number of epochs to train for.")
    parser.add_argument("--splits",
            required=False,
            action=SplitsParser,
            default=(85,5,10),
            help="Data point splits training/validation/testing. Data points are slices of AV files, not files themselves. Defaults to (85, 5, 10)%% for each. Numbers must add to 100.")
    parser.add_argument("--seed",
            required=False,
            type=int,
            default=None,
            help="Seed for PyTorch. It will affect the data split allocations.")
    parser.add_argument("--device",
            choices=["cpu", "cuda"],
            required=False,
            default=None,
            help="Force a device, either cpu or cuda.")
    # Keep the module-level name ``args``: other code in this file may read it.
    args = parser.parse_args()

    # --image is mandatory for inference only; report its absence as a
    # usage error (message + exit status 2) rather than a traceback.
    if args.command == 'infer' and not args.image:
        parser.error('Missing image file to run inference on.')

    if args.command == 'train':
        train(args)
    elif args.command == 'infer':
        print(infer(args.dataset, args.image))
