import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from sdict import sdict
from simmanager import SimManager
from torch.utils.data import DataLoader
from yaml import Loader

from egru import Timer, get_random_name, Optimizer
from egru.delay_copy_task_timing import DelayCopyDataTiming, get_delay_timing_data_full_batch, \
    get_delay_timing_onehot_data_full_batch
from egru.egruc import EGRUC, convolve_outputs, convolve_dynamics
from egru.models import EGRUThresholdInit


def main(c, p):
    """Train an EGRUC model on the delay-copy timing task, or only evaluate it when ``c.load`` is set.

    Each iteration runs a forward pass, computes the task loss plus an activity
    regularization term, optionally takes an optimizer step, and prints the
    bitwise success rate. The loop stops early when the network stops emitting
    activity or when the success rate exceeds 0.98.

    Args:
        c: Configuration object read by attribute. Fields used here: ``seed``,
            ``cuda``, ``minibatch_data``, ``one_hot_encoding``, ``dt``, ``length``,
            ``width``, ``initial_delay``, ``delay``, ``output_repeat_factor``,
            ``batch_size``, ``total_input_width``, ``target_width``, ``n_units``,
            ``frac_out_units``, ``thr_init``, ``bias_std``, ``load``, ``optimizer``,
            ``learning_rate``, ``n_training_iterations``,
            ``activity_regularization_constant`` and ``plot``.
        p: Object exposing ``results_path``, or a falsy value. When falsy, no
            model checkpoints are written.
    """
    print('Seed: ', c.seed)

    torch.manual_seed(c.seed)
    np.random.seed(c.seed)

    device = torch.device("cuda:0") if c.cuda else torch.device("cpu")
    input_spacing_factor = 0.1
    if not c.minibatch_data:
        print("Using full batch data")
        # Both generators take the same arguments; only the encoding differs.
        if c.one_hot_encoding:
            make_full_batch = get_delay_timing_onehot_data_full_batch
        else:
            make_full_batch = get_delay_timing_data_full_batch
        input_times, inputs, target_times, targets = make_full_batch(
            c.dt, c.length, c.width, c.initial_delay, c.delay,
            input_spacing_factor=input_spacing_factor,
            output_repeat_factor=c.output_repeat_factor)
        # With full-batch data the batch size is dictated by the generated arrays.
        batch_size = input_times.shape[0]
        input_times = torch.from_numpy(input_times).to(device)
        inputs = torch.from_numpy(inputs).to(device)
        target_times = torch.from_numpy(target_times).to(device)
        targets = torch.from_numpy(targets).to(device)
    else:
        print("Sampling from dataset")
        dataset = DelayCopyDataTiming(c.dt, c.seed, c.length, c.width, c.initial_delay, c.delay, c.batch_size,
                                      input_spacing_factor=input_spacing_factor,
                                      binary_encoding=(not c.one_hot_encoding),
                                      output_repeat_factor=c.output_repeat_factor)
        loader = DataLoader(dataset, batch_size=c.batch_size, num_workers=2, pin_memory=c.cuda)
        data_iter = iter(loader)
        batch_size = c.batch_size
    print('Batch/dataset size: ', batch_size)
    input_size = c.total_input_width
    output_size = c.target_width

    model = EGRUC(input_size=input_size, output_size=output_size, n_units=c.n_units,
                  frac_out_units=c.frac_out_units, thr_init=c.thr_init, bias_std=c.bias_std,
                  batch_size=batch_size)

    if c.load:
        print(f"LOADING MODEL from {c.load}")
        # Load onto the CPU first; the model is moved to the target device below.
        model.load_state_dict(torch.load(c.load, map_location=torch.device('cpu')))
        model.eval()
    else:
        model.train()
    print("Torchdiffeq based model")

    model = model.to(device)

    if c.one_hot_encoding:
        loss_function = nn.CrossEntropyLoss()
    else:
        loss_function = nn.BCEWithLogitsLoss()

    if c.optimizer == Optimizer.sgd:
        print("Using SGD")
        optimizer = optim.SGD(list(model.parameters()), lr=c.learning_rate)
    elif c.optimizer == Optimizer.adam:
        print("Using Adam")
        optimizer = optim.Adam(list(model.parameters()), lr=c.learning_rate)
    else:
        raise RuntimeError(f"Unknown optimizer {c.optimizer}")

    if p:
        # Snapshot of the untrained (or freshly loaded) weights.
        torch.save(model.state_dict(), os.path.join(p.results_path, 'models/egurc-init.pt'))

    running_avg_bitwise_success_rate = 0.
    for it in range(c.n_training_iterations):
        print(f"Iteration {it}")
        model.init_hidden()
        initial_state = model.get_initial_state()

        if c.minibatch_data:
            data = next(data_iter)
            input_times = data['input_times'].to(device)
            inputs = data['inputs'].to(device)
            target_times = data['target_times'].to(device)
            targets = data['targets'].to(device)

        with Timer() as bt:
            net_taus, c_ts, i_us, i_rs, i_cs, out_taus, couts, hs_t, all_taus = model(
                initial_state, input_times, inputs, target_times)

            mean_activity = torch.sum(hs_t / model.model.thr) / batch_size

            # Exactly zero activity means nothing is left to drive learning.
            if mean_activity == 0.:
                print("No more spikes. Learning has stopped")
                break

            selected_conv_values_arr = convolve_outputs(target_times, hs_t, all_taus, net_taus)
            # The first `output_size` entries of the last dimension are used directly as outputs.
            relevant_outputs = selected_conv_values_arr[..., :output_size]
            # relevant_outputs = model.hidden2out(selected_conv_values_arr)

            # Only the convolved values are needed (for the regularizer); the times are unused.
            conv_values_list, _ = convolve_dynamics(hs_t, all_taus, net_taus)

            if c.one_hot_encoding:
                # For CrossEntropyLoss, classes should be dim 1
                relevant_outputs_ = torch.transpose(relevant_outputs, 1, 2)
                targets_ = torch.transpose(targets, 1, 2)
                loss = loss_function(relevant_outputs_, torch.argmax(targets_, dim=1))
            else:
                loss = loss_function(relevant_outputs, targets)

            # Activity regularizer: squared mean of all convolved values.
            reg_loss = (torch.mean(torch.cat(conv_values_list))) ** 2
            total_loss = loss + c.activity_regularization_constant * reg_loss
        print(f"Forward pass Batch time was {bt.difftime:.4f}.")

        if not c.load:
            with Timer() as bt:
                optimizer.zero_grad()
                total_loss.backward()

                optimizer.step()
            print(f"Backward pass Batch time was {bt.difftime:.4f}.")

            # Diagnostic: report which parameters (other than 'thr') received a gradient signal.
            zero_grads = []
            non_zero_grads = []
            for name, param in model.model.named_parameters():
                if name == 'thr':
                    continue
                grad = param.grad
                # A parameter that took no part in the backward pass has `grad is None`;
                # list it with the all-zero ones instead of failing on the attribute access.
                if grad is None or (grad == 0.).all():
                    zero_grads.append(name)
                else:
                    non_zero_grads.append(name)

            print(f"{it} :: Parameters with zero grads: {zero_grads}")
            print(f"{it} :: Parameters with non-zero grads: {non_zero_grads}")
        else:
            print("loaded model. Not training")

        if c.one_hot_encoding:
            actual_output = torch.argmax(relevant_outputs, -1)
            bitwise_success_rate = (actual_output == torch.argmax(targets, -1)).float().sum() / (
                torch.numel(actual_output))
        else:
            # Threshold the sigmoid of the logits at 0.5 to obtain hard bits.
            actual_output = torch.where(torch.sigmoid(relevant_outputs) < 0.5, torch.zeros_like(relevant_outputs),
                                        torch.ones_like(relevant_outputs))
            bitwise_success_rate = (actual_output == targets).float().sum() / (torch.numel(targets))
        # Moving average that weights the current iteration and the history equally.
        running_avg_bitwise_success_rate += bitwise_success_rate.item()
        running_avg_bitwise_success_rate /= 2

        loss_value = loss.item()
        print(f"Training iteration {it} :: Loss is {loss_value:.4f} ::"
              f" Reg loss is {reg_loss.item():.4f} ::"
              f" Bitwise success rate {bitwise_success_rate.item():.4f}  (Running avg.  {running_avg_bitwise_success_rate:.4f}) ::"
              f" Mean activity {mean_activity.item():.4f} :: "
              )
        if p:
            torch.save(model.state_dict(), os.path.join(p.results_path, f'models/egurc-{it}.pt'))

        if c.plot:
            bi = 0  # index along dim 1 of the element that is plotted (presumably the batch index)

            # Imported lazily so matplotlib is only required when plotting is requested.
            import matplotlib.pyplot as plt
            fig, axs = plt.subplots(2, 1)

            finite_mask = torch.isfinite(net_taus[:, bi])
            valid_net_taus = torch.masked_select(net_taus[:, bi], finite_mask)
            valid_cts = torch.masked_select(c_ts[:, bi], finite_mask[..., None]) \
                .reshape(-1, c_ts.shape[-1])

            # `.cpu()` is required before `.numpy()` when the tensors live on the CUDA device.
            pnts = valid_net_taus.detach().cpu().numpy()
            pcts = valid_cts.detach().cpu().numpy()
            ax = axs[0]
            ax.plot(pnts, pcts, marker='x')

            ptots = out_taus.detach().cpu().numpy()
            pcots = couts.detach().cpu().numpy()
            ax = axs[1]
            ax.plot(ptots[:, bi], pcots[:, bi], marker='x')
            plt.show()

        if running_avg_bitwise_success_rate > 0.98:
            print(
                f"Training iteration {it} :: Loss is {loss_value:.4f} :: "
                f"Running avg. of bitwise success rate is high enough {running_avg_bitwise_success_rate:.4f}. "
                f"Stopping training.")
            break
        # With full-batch data every iteration sees the same samples, so a single good step suffices.
        if not c.minibatch_data and bitwise_success_rate > 0.98:
            print(
                f"Training iteration {it} :: Loss is {loss_value:.4f} :: "
                f"Single step bitwise success rate is high enough {bitwise_success_rate:.4f} for full batch data. "
                f"Stopping training.")
            break
        # NOTE(review): printed after every iteration that does not stop early, not once
        # at the end of training - confirm this is intended.
        print("DONE")


if __name__ == '__main__':
    # Command-line interface. The help texts describe how each value is used
    # further down in this script.
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--seed', type=int, default=5, help='Random seed for torch and numpy')
    argparser.add_argument('--load', type=str,
                           help='Path to a saved model state dict; the model is loaded and run without training')
    argparser.add_argument('--batch-size', type=int, default=10)
    argparser.add_argument('--dt', type=float, default=1.)
    argparser.add_argument('--learning-rate', type=float, default=0.1)
    argparser.add_argument('--optimizer', type=str, default='adam', choices=[o.value for o in Optimizer])
    argparser.add_argument('--initial-delay', type=int, required=True, help='Delay before string presentation')
    argparser.add_argument('--num-units', type=int, default=10, help='Number of units of the model (n_units)')
    argparser.add_argument('--bias-std', type=float, default=0.2)
    argparser.add_argument('--output-fraction', type=float, default=1.,
                           help='Passed to the model as frac_out_units')
    argparser.add_argument('--delay', type=int, required=True, help='Delay after string presentation')
    argparser.add_argument('--minibatch-data', action='store_true',
                           help='Sample minibatches from the dataset instead of using the full-batch data')
    argparser.add_argument('--cuda', action='store_true')
    argparser.add_argument('--train-iter', type=int, default=100)
    argparser.add_argument('--input-length', type=int, default=1)
    argparser.add_argument('--input-width', type=int, default=2)
    argparser.add_argument('--output-repeat-factor', type=int, default=1)
    argparser.add_argument('--activity-reg-constant', type=float, default=1.)
    argparser.add_argument('--thr-init', type=str, default='const-scalar',
                           choices=[e.value for e in EGRUThresholdInit])
    argparser.add_argument('--one-hot-encoding', action='store_true')
    argparser.add_argument('--debug', action='store_true')
    argparser.add_argument('--plot', action='store_true')
    argparser.add_argument('--nostore', action='store_true', help='Nothing is stored on disk')
    args = argparser.parse_args()

    if torch.cuda.is_available():
        if not args.cuda:
            print("WARNING: You have a CUDA device, so you should probably run with --cuda")
        else:
            # Seed the CUDA generator as well when the GPU is actually used.
            torch.cuda.manual_seed(args.seed)


    # START CONFIG
    def get_config():
        """Build the run configuration dictionary from the parsed ``args``.

        Copies the command-line values, resolves the optimizer and threshold
        initialisation strings to their enum members, derives the input/target
        sizes, prints the resulting dictionary and returns it.

        Raises:
            RuntimeError: if ``args.optimizer`` or ``args.thr_init`` matches no
                enum member value.
        """
        print('Generating dictionary of parameters')

        def member_with_value(enum_cls, raw):
            # Resolve a command-line string to the enum member carrying that value.
            found = next((member for member in enum_cls if raw == member.value), None)
            if found is None:
                raise RuntimeError(f"Unknown value {raw}")
            return found

        chosen_optimizer = member_with_value(Optimizer, args.optimizer)
        # Convert string argument to enum
        chosen_thr_init = member_with_value(EGRUThresholdInit, args.thr_init)

        # Debug runs are shortened to a handful of iterations.
        if args.debug:
            print("!!DEBUG!!")
            train_iters, test_iters = 10, 10
        else:
            train_iters, test_iters = args.train_iter, 2

        # Task specific parameters
        seq_length = args.input_length
        # in bits if binary encoding, number of symbols otherwise, for example 8
        seq_width = args.input_width

        config = {
            'plot': args.plot,
            'load': args.load,
            'dt': args.dt,
            'n_training_iterations': train_iters,
            'n_testing_iterations': test_iters,
            'seed': args.seed,
            'cuda': args.cuda,
            'length': seq_length,
            'width': seq_width,
            'minibatch_data': args.minibatch_data,
            'initial_delay': args.initial_delay,
            'one_hot_encoding': args.one_hot_encoding,
            'output_repeat_factor': args.output_repeat_factor,
            'delay': args.delay,
            'bias_std': args.bias_std,
            'batch_size': args.batch_size,
            'n_units': args.num_units,
            'frac_out_units': args.output_fraction,
            'learning_rate': args.learning_rate,
            'optimizer': chosen_optimizer,
            # One input channel more than the symbol width.
            'total_input_width': seq_width + 1,
            'total_input_length': 2 * seq_length + args.initial_delay + args.delay,
            'target_length': seq_length,
            'target_width': seq_width,
            'thr_init': chosen_thr_init,
            'activity_regularization_constant': args.activity_reg_constant,
        }
        print(config)
        return config


    ## END CONFIG
    config = get_config()

    if args.load:
        print(f"Loading config from {args.load}")
        # Keys that describe this invocation rather than the stored experiment.
        # The saved config records the `load` and `plot` values of the run that
        # wrote it, so copying them would replace the checkpoint path given on
        # the command line and the model would never be loaded.
        runtime_only_keys = ('cuda', 'debug', 'load', 'plot')
        # The config file is looked up two directories above the checkpoint, in `data/`.
        config_file = os.path.join(os.path.dirname(args.load), '../..', 'data', 'config.yaml')
        with open(config_file, 'r') as f:
            # NOTE(review): the full `Loader` can construct arbitrary Python objects,
            # so only load configs written by trusted runs. The full loader is presumably
            # needed because the dumped config contains enum members - confirm.
            loaded_config = yaml.load(f, Loader=Loader)
        for k, v in loaded_config.items():
            if k not in runtime_only_keys:
                config[k] = v
    config = sdict(config)

    ## START DIR NAMES
    rroot = os.path.expanduser(os.path.join('~', 'output'))

    print(rroot)
    root_dir = os.path.join(rroot, 'egruc')
    if args.debug:
        root_dir = os.path.join(rroot, 'tmp')  # NOTE: DEBUG
    os.makedirs(root_dir, exist_ok=True)
    sim_name = get_random_name(prefix='egruc')
    if args.nostore:
        # Run without a SimManager; `main` receives no paths and writes no checkpoints.
        paths = None
        print('Calling main')
        if args.debug:
            # Imported lazily so ipdb is only required for debug runs.
            from ipdb import launch_ipdb_on_exception

            with launch_ipdb_on_exception():
                main(config, paths)
        else:
            main(config, paths)
        print("No results stored.")
    else:
        with SimManager(sim_name, root_dir, write_protect_dirs=False, tee_stdx_to='output.log') as simman:
            paths = simman.paths
            print("Results will be stored in ", paths.results_path)
            os.makedirs(os.path.join(paths.results_path, 'models'), exist_ok=True)
            # Persist the configuration next to the results so the run can be reloaded via --load.
            with open(os.path.join(paths.data_path, 'config.yaml'), 'w') as f:
                yaml.dump(config.todict(), f, allow_unicode=True, default_flow_style=False)

            print('Calling main')
            # Unlike the --nostore path, no post-mortem debugger is attached here,
            # with or without --debug.
            main(config, paths)
            print("Results stored in ", paths.results_path)
