import torch
import copy
from dataclasses import dataclass
from datetime import datetime

from utils.utils import (
    get_permuted_idx, deterministic_NeuralSort, 
    gumbel_softmax, soft_nn,
    set_seed, parse_args, DefaultArgs
)
from utils.data_gen import generate_mnist_data
from models.models import MLPQueryModel, DatastructureModel, FeatureModel

@dataclass
class RunState:
    """Mutable container for everything a training run tracks and saves.

    Only the annotated attributes (``args``, the ``step_*`` counters and the
    ``loss_*`` values) are dataclass fields. The un-annotated attributes below
    are plain class-level defaults: they are not ``__init__`` parameters and
    are not part of the generated ``__repr__``/``__eq__``. They become
    instance attributes once assigned during a run.
    """
    # NOTE(review): annotated as ``dict`` but ``main`` assigns the parsed
    # ``Args`` object here and ``save_run_state`` reads ``args.out_dir`` from it.
    args: dict = None
    # "*_last" holds the most recent snapshot; "*_best_hard" holds the snapshot
    # taken when the hard (1-hot lookup) evaluation last improved.
    query_model_state_last = None
    query_model_state_best_hard = None
    optimizer_state_last = None
    optimizer_state_best_hard = None
    data_model_state_last = None
    data_model_state_best_hard = None
    extra_model_state_last = None
    extra_model_state_best_hard = None
    feature_model_state_last = None
    feature_model_state_best_hard = None
    # NOTE(review): step_best and loss_best are never written in this file;
    # only the *_hard variants and step_last are updated by train_model.
    step_best: int = 0
    step_best_hard: int = 0
    step_last: int = 0
    loss_best: float = float('inf')
    loss_best_hard: float = float('inf')

    # Image tensors set by main() from get_train_val_images().
    train_images = None
    val_images = None
    # Passed to generate_mnist_data as ``feats``; not assigned in this file.
    train_feats = None
    val_feats = None

@dataclass
class Args(DefaultArgs):
    """Command-line configurable hyper-parameters for this training script.

    Parsed by ``parse_args(Args)`` in ``main``. This file also reads
    ``args.device``, ``args.seed`` and ``args.out_dir``, which are not declared
    here — presumably they come from ``DefaultArgs``; confirm there.
    """
    # Basic training parameters
    batch_size: int = 1024  # batch size passed to generate_mnist_data (training and log-time eval)
    eval_batch_size: int = 5000  # not referenced in this file
    n_steps: int = 2000000  # number of iterations of the training loop
    weight_decay: float = 0.0001  # AdamW weight decay, applied to every parameter group

    # Learning rates
    # One AdamW parameter group per model; see create_optimizer.
    query_model_lr: float = 0.00005
    datastructure_model_lr: float = 0.00005
    extra_model_lr: float = 0.00005
    feature_model_lr: float = 0.0005

    # Query MLP architecture
    # Not read directly in this file; MLPQueryModel receives the whole Args object.
    query_mlp_hidden_dim: int = 1024
    query_mlp_num_layers: int = 3
    query_mlp_act_fn: str = 'relu'

    # Datamodel Transformer config
    # Forwarded to DatastructureModel as n_layer / n_embd / n_head / arch.
    datamodel_n_layer: int = 4
    datamodel_n_embd: int = 64
    datamodel_n_head: int = 8
    datamodel_arch: str = 'transformer'

    # Task parameters
    n_inputs: int = 8  # passed to generate_mnist_data and DatastructureModel
    max_queries: int = 4  # number of lookups performed by query_and_compute
    dim: int = 1  # passed to DatastructureModel
    n_vals: int = 16  # passed to get_train_val_images and generate_mnist_data
    n_samples: int = 1  # training images per value; also generate_mnist_data's n_image_samples
    only_feats: bool = False  # not referenced in this file

    # Datamodel config
    use_data_model: bool = True  # when False no DatastructureModel is built and n_extra must be 0
    extra_space_dim: int = 3
    sort_inputs: bool = False  # forwarded to generate_mnist_data
    n_extra: int = 0  # > 0 builds an extra_model and appends n_extra rows to the data structure
    permute: bool = True
    data_model_sample_rate: int = 4  # used as n_resample for data generation and the training forward pass

    # Operational flags
    adaptive: bool = True
    nn_closest: bool = True  # with adaptive, decides whether soft_nn picks the final value
    soft_nn_temp: float = 1.0  # temp passed to soft_nn
    optimize_ce: bool = True  # not referenced in this file
    gumbel_noise_scale: float = 1.0  # noise_scale passed to gumbel_softmax
    gumbel_temp: float = 2.0  # temperature passed to gumbel_softmax
    diff_sort_temp: float = 1.0  # tau passed to deterministic_NeuralSort; overridden during hard eval

    # Train settings
    save_model: bool = True
    # The save check sits inside the logging branch of train_model, so a
    # periodic save only happens on steps divisible by both save_freq and log_freq.
    save_freq: int = 2000
    log_freq: int = 500  # hard evaluation and logging run every log_freq steps
    run_eval_freq: int = 2000  # not referenced in this file

def query_and_compute(query_model, data_model, extra_model, inputs, query_data, args: Args, hard=False, n_resample=1):
    """Build the data structure (optionally) and iteratively query it.

    Args:
        query_model: callable ``(query_data, values, masks, i)`` returning the
            lookup scores for step ``i``. ``values`` and ``masks`` are the lists
            of results from the previous steps.
        data_model: optional model producing a ranking of the input points. When
            ``None`` the inputs are used in their given order.
        extra_model: optional model whose output rows are appended to the
            permuted inputs. Only consulted when ``data_model`` is given.
        inputs: batched input points; used as the right operand of ``bmm``, so
            a 3-D ``(batch, n_points, dim)`` tensor.
        query_data: the queries, forwarded to ``query_model`` and ``soft_nn``.
        args: run configuration (temperatures, ``max_queries``, ``n_extra`` ...).
        hard: if true, every lookup mask is an argmax 1-hot vector (inference);
            otherwise masks are soft Gumbel-softmax samples so gradients flow.
        n_resample: only the first ``len(inputs) // n_resample`` batch rows are
            fed to the data/extra models and the result is tiled ``n_resample``
            times. Presumably the batch repeats the same datasets that many
            times — confirm against ``generate_mnist_data``.

    Returns:
        ``(values, masks, inputs, permutations)``: the looked-up values and the
        lookup masks stacked along dim 1 (with one extra "final" entry when the
        nearest-to-query selection runs), the possibly permuted/extended inputs,
        and the permutation matrices (``None`` without a data model).
    """
    permutations = None

    ## First we generate the datastructure
    if data_model is not None:
        n_unique = len(inputs) // n_resample
        ## Generate ranking for each of the points in the input datasets
        ordering, _ = data_model(inputs[:n_unique])
        ordering = ordering.repeat(n_resample, 1, 1)
        ## Generate permutation matrix based on rankings
        permutations = deterministic_NeuralSort(ordering, tau=args.diff_sort_temp, hard=hard)
        ## Apply permutation
        inputs = permutations.bmm(inputs)
        ## (Optionally) Add extra pre-computed statistics to data structure
        if extra_model is not None:
            # NOTE(review): data_model's output is unpacked as a 2-tuple above,
            # while extra_model's output is used directly as a tensor — confirm
            # both return conventions.
            extras = extra_model(inputs[:n_unique])
            extras = extras.repeat(n_resample, 1, 1)
            inputs = torch.cat([inputs, extras], dim=1)

    ## Lookup masks
    masks = []
    ## Lookup values, i.e. value in datastructure stored at corresponding lookup mask position
    values = []

    for i in range(args.max_queries):
        ## Generate lookup position conditioned on previous lookup positions and their values
        mask_i = query_model(query_data, values, masks, i)
        if hard:
            ## Argmax sample creating a 1-hot lookup mask, only for inference time
            mask_i = torch.nn.functional.one_hot(mask_i.argmax(dim=-1), mask_i.size(-1)).float()
        else:
            ## During training we take a 'soft' mask so gradients flow through nicely
            mask_i = gumbel_softmax(
                mask_i,
                temperature=args.gumbel_temp,
                hard=False,
                noise_scale=args.gumbel_noise_scale,
            )

        if args.use_data_model and args.n_extra > 0 and i == args.max_queries - 1:
            ## Final query model can only predict one of the original items, not extra inputs. However, we need to
            ##  resize the final mask so that it has the same size as intermediate lookup masks so that we can
            ##  stack tensors nicely.
            # Allocate the padding directly on the mask's device rather than on CPU followed by a copy.
            padding = torch.zeros(mask_i.size(0), args.n_extra, device=mask_i.device)
            mask_i = torch.cat([mask_i, padding], dim=1)

        ## Apply lookup to recover value
        value_i = torch.bmm(mask_i[:, None], inputs).squeeze(1)
        masks.append(mask_i)
        values.append(value_i)

    values = torch.stack(values, dim=1)
    masks = torch.stack(masks, dim=1)

    if not args.adaptive or args.nn_closest:
        ## Select the value closest to the query as the final value
        final_value, scores = soft_nn(query_data, values, temp=args.soft_nn_temp)
        final_mask_idx = scores.argmax(-1)
        batch_idx = torch.arange(masks.size(0), device=masks.device)
        final_mask = masks[batch_idx, final_mask_idx]
        values = torch.cat([values, final_value.unsqueeze(1)], dim=1)
        masks = torch.cat([masks, final_mask.unsqueeze(1)], dim=1)
    return values, masks, inputs, permutations


def get_loss_and_acc(masks, permutations, nn_ids):
    """Compute the cross-entropy loss of the final lookup and per-lookup accuracy.

    Args:
        masks: lookup masks indexed as ``masks[batch, query_step, slot]``; the
            last ``query_step`` entry is the final prediction.
        permutations: permutation matrices applied to the inputs, or ``None``
            when no data model reordered them.
        nn_ids: for each batch element, the index (in the original input
            order) of the true nearest neighbour.

    Returns:
        ``(loss, accs)`` where ``loss`` is the batch-mean negative log
        probability the final mask assigns to the target, and ``accs`` holds,
        for every query step, the fraction of the batch whose argmax lookup
        position equals the (permuted) nearest-neighbour position.
    """
    eps = 1e-7  # keeps log() finite for zero-probability slots
    hard_mask_pos = masks.argmax(-1)
    if permutations is not None:
        n_inputs = permutations.size(-1)
        nn_permuted_idx = get_permuted_idx(permutations, nn_ids)
        # Column nn_ids of each permutation matrix is used as the target
        # distribution over permuted slots.
        targets = permutations[torch.arange(len(masks)), :, nn_ids]
    else:
        # No permutation was applied, so slots keep their original order and
        # the mask width itself is the number of inputs. (Previously
        # ``permutations.size`` was read before this None check and crashed.)
        n_inputs = masks.size(-1)
        targets = torch.nn.functional.one_hot(nn_ids, num_classes=n_inputs)
        nn_permuted_idx = nn_ids
    # Only the first n_inputs slots are scored; any further slots are excluded from the loss.
    log_preds = masks[:, -1, :n_inputs].add(eps).log()
    loss = -log_preds.mul(targets).sum(-1).mean()
    accs = (nn_permuted_idx[:, None] == hard_mask_pos).float().mean(0)
    return loss, accs

def get_train_val_images(n_vals, n_train_samples, n_val_samples, data_path='./data/mnist_200_2000.pt'):
    """Load the image file and slice out the training and validation subsets.

    Args:
        n_vals: number of leading entries kept along the first axis of both
            splits (presumably one entry per digit value — confirm against the
            data file).
        n_train_samples: number of leading entries kept along the second axis
            of the ``'train'`` split.
        n_val_samples: number of leading entries kept along the second axis of
            the ``'test'`` split.
        data_path: location of the ``torch.save``-d mapping with ``'train'`` and
            ``'test'`` tensors. Defaults to the path this script has always used.

    Returns:
        ``(train_images, val_images)`` with axis 2 squeezed out of each.
    """
    images = torch.load(data_path)
    train_images = images['train'][:n_vals, :n_train_samples].squeeze(2)
    val_images = images['test'][:n_vals, :n_val_samples].squeeze(2)

    return train_images, val_images

def _store_last_state(run_state, step, optimizer, data_model, query_model, extra_model, feature_model):
    """Record ``step`` and the current model/optimizer state dicts as the 'last' snapshot."""
    run_state.step_last = step
    run_state.query_model_state_last = query_model.state_dict()
    run_state.data_model_state_last = data_model.state_dict() if data_model is not None else None
    run_state.extra_model_state_last = extra_model.state_dict() if extra_model is not None else None
    run_state.feature_model_state_last = feature_model.state_dict() if feature_model is not None else None
    run_state.optimizer_state_last = optimizer.state_dict()


def _store_best_hard_state(run_state, step, loss, optimizer, data_model, query_model, extra_model, feature_model):
    """Record deep copies of the current states as the best 'hard' evaluation snapshot."""
    run_state.loss_best_hard = loss
    run_state.step_best_hard = step
    # Deep copies: state_dict() hands back references to the live tensors,
    # which later optimizer steps would otherwise overwrite.
    run_state.query_model_state_best_hard = copy.deepcopy(query_model.state_dict())
    if data_model is not None:
        run_state.data_model_state_best_hard = copy.deepcopy(data_model.state_dict())
    if extra_model is not None:
        run_state.extra_model_state_best_hard = copy.deepcopy(extra_model.state_dict())
    if feature_model is not None:
        run_state.feature_model_state_best_hard = copy.deepcopy(feature_model.state_dict())
    run_state.optimizer_state_best_hard = copy.deepcopy(optimizer.state_dict())


def train_model(run_state: RunState, args: Args, optimizer, data_model, query_model, extra_model, feature_model):
    """Run the training loop, periodically evaluating, logging and checkpointing.

    Every step draws a fresh batch, runs the soft (differentiable) lookup
    pipeline and takes one optimizer step. Every ``args.log_freq`` steps a new
    batch is evaluated with hard 1-hot lookups; the negative accuracy of the
    final lookup is used as the model-selection loss. When it improves, a deep
    copy of all states is stored on ``run_state``. Within that same branch,
    every ``args.save_freq`` steps (if ``args.save_model``) the latest states
    are written to disk via ``save_run_state``.

    On return the ``*_last`` attributes of ``run_state`` describe the final
    state. ``data_model``, ``extra_model`` and ``feature_model`` may be ``None``.
    With ``args.n_steps == 0`` no training happens and ``step_last`` is 0.
    """
    step_start = datetime.now()
    steps_since_log = 0
    for step in range(args.n_steps):

        ## We first generate data points, a set of queries and their nearest neighbors
        inputs, query, _y, nn_ids, _in_vals, _q_vals, _t_vals = generate_mnist_data(
            args.n_inputs,
            args.batch_size,
            device=args.device,
            n_resample=args.data_model_sample_rate,
            sort_inputs=args.sort_inputs,
            images=run_state.train_images,
            feature_model=feature_model,
            n_vals=args.n_vals,
            n_image_samples=args.n_samples,
            feats=run_state.train_feats)

        ## Now we generate the dataset and execute the queries
        _, masks, _, permutations = query_and_compute(
            query_model, data_model, extra_model,
            inputs, query, args, n_resample=args.data_model_sample_rate
            )
        optimizer.zero_grad()
        train_soft_loss, _ = get_loss_and_acc(masks, permutations, nn_ids)
        train_soft_loss.backward()
        optimizer.step()
        steps_since_log += 1

        if step % args.log_freq == 0:
            # Divide by the steps actually run since the last log line (one at
            # step 0, log_freq afterwards) rather than always by log_freq.
            time_per_step = (datetime.now() - step_start).total_seconds() / steps_since_log
            step_start = datetime.now()
            steps_since_log = 0

            ## Run eval using 1-hot lookups (referred to as 'hard')
            # NOTE(review): this evaluation batch is drawn from train_images /
            # train_feats, not val_images — confirm that is intended.
            inputs, query, _y, nn_ids, _in_vals, _q_vals, _t_vals = generate_mnist_data(
                args.n_inputs,
                args.batch_size,
                device=args.device,
                n_resample=args.data_model_sample_rate,
                sort_inputs=args.sort_inputs,
                images=run_state.train_images,
                feature_model=feature_model,
                n_vals=args.n_vals,
                n_image_samples=args.n_samples,
                feats=run_state.train_feats)

            # Temporarily use a near-zero sort temperature for the hard pass;
            # restore it even if evaluation raises so training is unaffected.
            old_sort_temp = args.diff_sort_temp
            args.diff_sort_temp = 0.0000001
            try:
                with torch.no_grad():
                    _, masks, _, permutations = query_and_compute(
                        query_model, data_model, extra_model,
                        inputs, query, args, hard=True)
            finally:
                args.diff_sort_temp = old_sort_temp
            _, hard_accs = get_loss_and_acc(masks, permutations, nn_ids)
            hard_acc = hard_accs[-1].item()
            # Model selection minimises negative accuracy of the final lookup.
            val_hard_loss = -hard_acc

            print(
                f'Step {step}, '
                f'Total Loss: {train_soft_loss.item():.3f}, '
                f'Hard Acc: {hard_acc:.2f}, '
                f'TPS: {time_per_step:.4f}'
            )

            if val_hard_loss < run_state.loss_best_hard:
                _store_best_hard_state(
                    run_state, step, val_hard_loss, optimizer,
                    data_model, query_model, extra_model, feature_model)
            if step % args.save_freq == 0 and args.save_model:
                _store_last_state(
                    run_state, step, optimizer,
                    data_model, query_model, extra_model, feature_model)
                save_run_state(run_state)

    # The loop variable is unbound when n_steps == 0, so derive the last step
    # index from n_steps instead of reading it after the loop.
    _store_last_state(
        run_state, max(args.n_steps - 1, 0), optimizer,
        data_model, query_model, extra_model, feature_model)


def save_run_state(run_state):
    """Write the whole ``run_state`` object to ``out.pk`` inside ``run_state.args.out_dir``."""
    checkpoint_path = run_state.args.out_dir / 'out.pk'
    torch.save(run_state, checkpoint_path)

def create_optimizer(args: Args, query_model, data_model, extra_model, feature_model):
    """Build one AdamW optimizer with a separate learning rate per model.

    Args:
        args: supplies ``query_model_lr``, ``datastructure_model_lr``,
            ``extra_model_lr``, ``feature_model_lr`` and ``weight_decay``.
        query_model: always optimized.
        data_model: optimized when not ``None``.
        extra_model: optimized when not ``None``.
        feature_model: optimized when not ``None``. The rest of this file treats
            it as optional, so it is guarded here too.

    Returns:
        A ``torch.optim.AdamW`` whose parameter groups appear in the order
        query, data, extra, feature (absent models skipped), all sharing
        ``args.weight_decay``.
    """
    param_groups = [{"params": query_model.parameters(), "lr": args.query_model_lr}]
    optional_groups = (
        (data_model, args.datastructure_model_lr),
        (extra_model, args.extra_model_lr),
        (feature_model, args.feature_model_lr),
    )
    for model, lr in optional_groups:
        if model is not None:
            param_groups.append({"params": model.parameters(), "lr": lr})
    return torch.optim.AdamW(param_groups, weight_decay=args.weight_decay)

def main():
    """Parse arguments, build the models and optimizer, train, and save the run.

    Raises:
        ValueError: if ``n_extra`` is non-zero while ``use_data_model`` is off;
            extra rows are only ever produced alongside a data model.
    """
    args: Args = parse_args(Args)

    run_state = RunState()

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not args.use_data_model and args.n_extra != 0:
        raise ValueError("n_extra must be 0 when use_data_model is False")

    run_state.args = args
    print(args)
    set_seed(args.seed)

    if args.save_model:
        args.out_dir.mkdir(parents=True, exist_ok=True)

    # Model initialization
    extra_model = None
    data_model = None
    query_model = MLPQueryModel(args).to(args.device)
    feature_model = FeatureModel().to(args.device)
    if args.use_data_model:
        # Model that ranks the inputs (output_dim 1 when permuting).
        data_model = DatastructureModel(
            args.n_inputs,
            args.n_inputs,
            args.dim,
            extra_space_dim=args.extra_space_dim,
            permute=args.permute,
            n_layer=args.datamodel_n_layer,
            n_embd=args.datamodel_n_embd,
            n_head=args.datamodel_n_head,
            output_dim=(1 if args.permute else args.dim),
            arch=args.datamodel_arch
        ).to(args.device)

        if args.n_extra > 0:
            # Second model producing n_extra additional rows for the data structure.
            extra_model = DatastructureModel(
                args.n_inputs,
                args.n_extra,
                args.dim,
                extra_space_dim=args.extra_space_dim,
                permute=False,
                n_layer=args.datamodel_n_layer,
                n_embd=args.datamodel_n_embd,
                n_head=args.datamodel_n_head,
                output_dim=args.extra_space_dim,
                arch=args.datamodel_arch
            ).to(args.device)

    optimizer = create_optimizer(args, query_model, data_model, extra_model, feature_model)

    print(data_model)
    print(extra_model)
    print(query_model)

    run_state.train_images, run_state.val_images = get_train_val_images(args.n_vals, args.n_samples, 100)
    run_state.train_images = run_state.train_images.to(args.device)
    run_state.val_images = run_state.val_images.to(args.device)

    train_model(
        run_state,
        args,
        optimizer,
        data_model,
        query_model,
        extra_model,
        feature_model,
    )

    if args.save_model:
        # train_model only maintains loss_best_hard; loss_best is never
        # updated and would always print as inf.
        print("Best loss:", run_state.loss_best_hard)
        print(f'Saving model to:\n{args.out_dir/"out.pk"}')
        save_run_state(run_state)

if __name__ == '__main__':
    # Script entry point: run the full pipeline and report wall-clock duration.
    run_started_at = datetime.now()
    main()
    elapsed = datetime.now() - run_started_at
    print('Total time:', elapsed)
