import copy
from dataclasses import dataclass, field
from datetime import datetime

import torch

from utils.utils import (
    get_permuted_idx, deterministic_NeuralSort, 
    gumbel_softmax, soft_nn, get_closest,
    set_seed, parse_args, DefaultArgs, compute_acc
)
from utils.data_gen import generate_data
from models.models import MLPQueryModel, DatastructureModel

@dataclass
class RunState:
    """Mutable record of a training run: config, progress counters and state dicts.

    The whole object is what ``save_run_state`` hands to ``torch.save``.

    The ``*_state_last`` / ``*_state_best_hard`` entries hold the ``state_dict()``
    outputs of the query model, data model, extra model and optimizer. They are
    declared as real dataclass fields (previously they were un-annotated class
    attributes, so ``dataclasses.fields``/``asdict``/``replace`` and ``__init__``
    silently ignored them). They are placed after the original fields so the
    positional ``__init__`` order is unchanged, and are excluded from ``repr``
    and ``==`` so those keep behaving exactly as before (no tensor dumps, no
    tensor comparisons).
    """

    # Parsed run configuration (an ``Args`` instance; ``main`` assigns it).
    args: "Args | None" = None

    # Progress counters. Within this file only ``step_best_hard``, ``step_last``
    # and ``loss_best_hard`` are written by ``train_model``; ``step_best`` and
    # ``loss_best`` keep their defaults.
    step_best: int = 0
    step_best_hard: int = 0
    step_last: int = 0
    loss_best: float = float('inf')
    loss_best_hard: float = float('inf')

    # Most recent snapshot ("last") and snapshot at the lowest hard-eval loss
    # ("best_hard") for each component; ``None`` until first recorded.
    query_model_state_last: "dict | None" = field(default=None, repr=False, compare=False)
    query_model_state_best_hard: "dict | None" = field(default=None, repr=False, compare=False)
    optimizer_state_last: "dict | None" = field(default=None, repr=False, compare=False)
    optimizer_state_best_hard: "dict | None" = field(default=None, repr=False, compare=False)
    data_model_state_last: "dict | None" = field(default=None, repr=False, compare=False)
    data_model_state_best_hard: "dict | None" = field(default=None, repr=False, compare=False)
    extra_model_state_last: "dict | None" = field(default=None, repr=False, compare=False)
    extra_model_state_best_hard: "dict | None" = field(default=None, repr=False, compare=False)

@dataclass
class Args(DefaultArgs):
    # Run configuration; instantiated in main() through parse_args(Args).
    # Fields inherited from DefaultArgs (e.g. device, seed, out_dir as used
    # elsewhere in this file) are not repeated here.

    # ---- Optimisation loop ----
    batch_size: int = 1024            # batch size given to generate_data for training batches
    eval_batch_size: int = 5000       # batch size given to generate_data for the hard eval
    n_steps: int = 2_000_000          # number of iterations of the training loop
    weight_decay: float = 1e-4        # AdamW weight_decay

    # ---- Per-parameter-group learning rates (see create_optimizer) ----
    query_model_lr: float = 5e-5
    datastructure_model_lr: float = 5e-5
    extra_model_lr: float = 5e-3

    # ---- Query MLP ----
    # Not read in this file; presumably consumed by MLPQueryModel(args) -- TODO confirm.
    query_mlp_hidden_dim: int = 1024
    query_mlp_num_layers: int = 3
    query_mlp_act_fn: str = "relu"

    # ---- Data-structure model (forwarded to DatastructureModel in main) ----
    datamodel_n_layer: int = 4
    datamodel_n_embd: int = 64
    datamodel_n_head: int = 8
    datamodel_arch: str = "transformer"

    # ---- Task / data generation (mostly forwarded to generate_data) ----
    n_inputs: int = 50                # also the first two DatastructureModel arguments
    max_queries: int = 5              # number of lookups performed per query
    data_sample_range: float = 1
    query_sample_range: float = 1
    query_corr: float = 1.0
    flip_ratio: float = 0.2
    input_distribution_type: str = "uniform"  # some names make main() load ann_data/<name>.pk
    nn_is_query: bool = False
    random_query: bool = True
    dim: int = 1                      # also the column cut-off applied to loaded datasets
    worst_case_var: float = 7.0
    n_vals: int = 60
    zipf_alpha: float = 1.25
    data_model_sample_rate: int = 4   # passed as n_resample during training

    # ---- Data-structure options ----
    use_data_model: bool = True       # build a DatastructureModel at all
    extra_space_dim: int = 3
    sort_inputs: bool = False
    n_extra: int = 0                  # > 0 additionally builds the extra model
    permute: bool = True

    # ---- Lookup behaviour ----
    adaptive: bool = True
    nn_closest: bool = True           # with adaptive, decides whether the soft_nn step runs
    soft_nn_temp: float = 1.0         # temp argument of soft_nn
    optimize_ce: bool = True          # train on the cross-entropy loss instead of final-query MSE
    gumbel_noise_scale: float = 1.0   # noise_scale argument of gumbel_softmax
    gumbel_temp: float = 2.0          # temperature argument of gumbel_softmax
    diff_sort_temp: float = 1.0       # tau argument of deterministic_NeuralSort

    # ---- Logging / checkpointing ----
    save_model: bool = True
    save_freq: int = 2000
    log_freq: int = 500
    run_eval_freq: int = 2000         # not referenced anywhere in this file

def query_and_compute(query_model, data_model, extra_model, inputs, query_data, args: "Args", hard=False, n_resample=1):
    """Build the (optional) learned data structure, then query it ``args.max_queries`` times.

    Args:
        query_model: callable invoked as ``query_model(query_data, values, masks, i)``;
            must return a 2-D tensor of per-position scores for lookup ``i``.
        data_model: optional model whose first output is a ranking that is turned
            into a permutation of ``inputs``; ``None`` leaves ``inputs`` untouched.
        extra_model: optional model (only used when ``data_model`` is given) whose
            output is appended to the permuted inputs along dim 1.
        inputs: 3-D batch-first tensor (it is used with ``bmm``). The first
            ``len(inputs) // n_resample`` entries are fed to the models and the
            result is tiled ``n_resample`` times, so the batch is assumed to be
            that block repeated -- TODO confirm against ``generate_data``.
        query_data: passed through to ``query_model`` and ``soft_nn``.
        args: configuration; reads ``max_queries``, ``diff_sort_temp``,
            ``gumbel_temp``, ``gumbel_noise_scale``, ``use_data_model``,
            ``n_extra``, ``adaptive``, ``nn_closest`` and ``soft_nn_temp``.
        hard: if truthy, lookups are one-hot argmax masks (inference); otherwise
            soft Gumbel-softmax masks so gradients can flow.
        n_resample: tiling factor described above.

    Returns:
        ``(values, masks, inputs, permutations)`` where ``values`` and ``masks``
        are stacked along dim 1 (one entry per lookup, plus one extra
        "closest to the query" entry when ``not args.adaptive or
        args.nn_closest``), ``inputs`` is the data structure that was queried and
        ``permutations`` is the permutation tensor or ``None`` without a data model.
    """
    permutations = None

    # --- Build the data structure ---
    if data_model is not None:
        n_unique = len(inputs) // n_resample
        # Ranking for each point, computed once per distinct dataset and tiled.
        ordering, _ = data_model(inputs[:n_unique])
        ordering = ordering.repeat(n_resample, 1, 1)
        # Ranking -> (soft or hard) permutation matrix, then reorder the inputs.
        permutations = deterministic_NeuralSort(ordering, tau=args.diff_sort_temp, hard=hard)
        inputs = permutations.bmm(inputs)
        if extra_model is not None:
            # Optional extra pre-computed entries appended after the permuted inputs.
            extras = extra_model(inputs[:n_unique]).repeat(n_resample, 1, 1)
            inputs = torch.cat([inputs, extras], dim=1)

    # The last lookup may only address the original items, so its mask is
    # narrower by n_extra and has to be zero-padded before stacking.
    pad_final_mask = bool(args.use_data_model and args.n_extra > 0)
    last_query = args.max_queries - 1

    masks = []   # lookup masks, one per query step
    values = []  # value read from the data structure at each mask

    for i in range(args.max_queries):
        # Next lookup position, conditioned on the previous lookups and their values.
        mask_i = query_model(query_data, values, masks, i)
        if hard:
            # Inference: one-hot argmax lookup.
            mask_i = torch.nn.functional.one_hot(mask_i.argmax(dim=-1), mask_i.size(-1)).float()
        else:
            # Training: soft mask so gradients flow through the lookup.
            mask_i = gumbel_softmax(
                mask_i,
                temperature=args.gumbel_temp,
                hard=False,
                noise_scale=args.gumbel_noise_scale,
            )

        if pad_final_mask and i == last_query:
            # Allocate the padding directly on the mask's device.
            padding = torch.zeros(mask_i.size(0), args.n_extra, device=mask_i.device)
            mask_i = torch.cat([mask_i, padding], dim=1)

        # Apply the lookup: (B, 1, N) @ (B, N, D) -> (B, D).
        value_i = torch.bmm(mask_i[:, None], inputs).squeeze(1)
        masks.append(mask_i)
        values.append(value_i)

    values = torch.stack(values, dim=1)
    masks = torch.stack(masks, dim=1)

    if not args.adaptive or args.nn_closest:
        # Append, as the final answer, the looked-up value closest to the query
        # together with the mask of the highest-scoring lookup.
        final_value, scores = soft_nn(query_data, values, temp=args.soft_nn_temp)
        final_mask_idx = scores.argmax(-1)
        batch_idx = torch.arange(masks.size(0), device=masks.device)
        final_mask = masks[batch_idx, final_mask_idx]
        values = torch.cat([values, final_value.unsqueeze(1)], dim=1)
        masks = torch.cat([masks, final_mask.unsqueeze(1)], dim=1)
    return values, masks, inputs, permutations

def _record_last_state(run_state, step, optimizer, data_model, query_model, extra_model):
    """Store the current step and the live state dicts as the 'last' snapshot on run_state."""
    run_state.step_last = step
    run_state.query_model_state_last = query_model.state_dict()
    run_state.data_model_state_last = data_model.state_dict() if data_model is not None else None
    run_state.extra_model_state_last = extra_model.state_dict() if extra_model is not None else None
    run_state.optimizer_state_last = optimizer.state_dict()


def train_model(run_state: RunState, args: Args, optimizer, data_model, query_model, extra_model, saved_datasets):
    """Run the training loop, periodically evaluating with hard (one-hot) lookups.

    Each step draws a fresh batch with ``generate_data``, runs
    ``query_and_compute`` with soft lookups and takes one optimizer step on
    either the cross-entropy loss (``args.optimize_ce``) or the MSE of the
    final query. Every ``args.log_freq`` steps a hard evaluation is run and
    printed, the best-hard-loss snapshot on ``run_state`` is refreshed, and --
    only on those logging steps that are also multiples of ``args.save_freq``
    -- the 'last' snapshot is written via ``save_run_state``.
    ``args.run_eval_freq`` is not consulted here.

    Args:
        run_state: updated in place (best-hard and last snapshots, counters).
        args: run configuration.
        optimizer: optimizer over the models' parameters.
        data_model: data-structure model or ``None``.
        query_model: query model.
        extra_model: extra-statistics model or ``None``.
        saved_datasets: mapping with ``'train'`` and ``'test'`` entries that are
            forwarded to ``generate_data`` as ``saved_data``.

    Returns:
        None.
    """
    uses_permutation = bool(args.use_data_model and args.permute)

    # Keeps `step` defined (and step_last unchanged) when n_steps == 0.
    step = run_state.step_last
    last_logged_step = -1
    step_start = datetime.now()

    for step in range(args.n_steps):
        ## Generate data points, a set of queries and their nearest neighbors
        input_data, query_data, y, nn_ids = generate_data(
            args.n_inputs,
            args.batch_size,
            args.dim,
            args.data_sample_range,
            args.query_sample_range,
            dist_type=args.input_distribution_type,
            nn_is_query=args.nn_is_query,
            device=args.device,
            n_resample=args.data_model_sample_rate,
            sort_inputs=args.sort_inputs,
            random_query=args.random_query,
            query_corr=args.query_corr,
            flip_ratio=args.flip_ratio,
            worst_case_var=args.worst_case_var,
            n_vals=args.n_vals,
            zipf_alpha=args.zipf_alpha,
            saved_data=saved_datasets['train'],
        )

        ## Build the data structure and execute the (soft) queries
        _, masks, transformed_inputs, permutations = query_and_compute(
            query_model, data_model, extra_model,
            input_data, query_data, args, n_resample=args.data_model_sample_rate,
        )

        # Nearest-neighbour index expressed in the (possibly permuted) layout.
        target_idx = get_permuted_idx(permutations, nn_ids) if uses_permutation else nn_ids
        final_one_hot_mask_acc = compute_acc(masks[:, -1], target_idx)

        # MSE between the value each mask reads out and the target, per query step.
        retrieved = (masks[..., None] * transformed_inputs[:, None]).sum(dim=-2)
        mse_loss_all_queries = (retrieved - y.unsqueeze(1)).square().mean(dim=[0, 2])

        if args.optimize_ce:
            # Cross-entropy between the final mask and the target position.
            log_values = masks[:, -1, :args.n_inputs].log()
            if permutations is not None:
                targets = permutations[torch.arange(len(masks)), :, nn_ids]
            else:
                targets = torch.nn.functional.one_hot(nn_ids, num_classes=masks.size(-1))
            loss = -log_values.mul(targets).sum(-1).mean()
        else:
            loss = mse_loss_all_queries[-1]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % args.log_freq == 0:
            now = datetime.now()
            # Divide by the steps actually elapsed (1 at step 0, log_freq afterwards).
            time_per_step = (now - step_start).total_seconds() / (step - last_logged_step)
            step_start = now
            last_logged_step = step

            ## Run eval using 1-hot lookups (referred to as 'hard')
            eval_inputs, eval_query, eval_target, eval_nn_ids = generate_data(
                args.n_inputs,
                args.eval_batch_size,
                args.dim,
                args.data_sample_range,
                args.query_sample_range,
                dist_type=args.input_distribution_type,
                nn_is_query=args.nn_is_query,
                device=args.device,
                sort_inputs=args.sort_inputs,
                random_query=args.random_query,
                query_corr=args.query_corr,
                flip_ratio=args.flip_ratio,
                worst_case_var=args.worst_case_var,
                n_vals=args.n_vals,
                zipf_alpha=args.zipf_alpha,
                saved_data=saved_datasets['test'],
            )
            with torch.no_grad():
                # Near-zero temperatures for the eval forward pass; always
                # restored so a failure here cannot leak into later training steps.
                old_sort_temp = args.diff_sort_temp
                old_soft_nn_temp = args.soft_nn_temp
                args.diff_sort_temp = 1e-10
                args.soft_nn_temp = 1e-10
                try:
                    eval_values, eval_masks, _, eval_permutations = query_and_compute(
                        query_model, data_model, extra_model,
                        eval_inputs, eval_query, args, hard=True)
                finally:
                    args.diff_sort_temp = old_sort_temp
                    args.soft_nn_temp = old_soft_nn_temp

            eval_mask_pos = eval_masks.argmax(-1)
            if uses_permutation:
                eval_target_idx = get_permuted_idx(eval_permutations, eval_nn_ids)
            else:
                eval_target_idx = eval_nn_ids
            hard_accs = (eval_target_idx[:, None] == eval_mask_pos).float().mean(0)
            hard_loss = (eval_values - eval_target.unsqueeze(1)).square().mean(dim=[0, 2])[-1].item()

            _, closest_masks = get_closest(eval_values, eval_query, eval_masks)
            closest_hard_acc = (closest_masks.argmax(-1) == eval_target_idx[:, None]).float().mean(0)[-1].item()

            print(
                f'Step {step}, '
                f'Total Loss: {loss.item():.3f}, '
                f'MSE Loss Final Query: {mse_loss_all_queries[-1].item():.3f}, '
                f'Mask Acc: {final_one_hot_mask_acc.item():.2f}, '
                f'Hard Acc: {hard_accs[-1].item():.2f}, '
                f'Closest Acc: {closest_hard_acc:.2f}, '
                f'TPS: {time_per_step:.4f}'
            )

            if hard_loss < run_state.loss_best_hard:
                # Deep copies: the live state dicts keep changing as training continues.
                run_state.loss_best_hard = hard_loss
                run_state.step_best_hard = step
                run_state.query_model_state_best_hard = copy.deepcopy(query_model.state_dict())
                if data_model is not None:
                    run_state.data_model_state_best_hard = copy.deepcopy(data_model.state_dict())
                if extra_model is not None:
                    run_state.extra_model_state_best_hard = copy.deepcopy(extra_model.state_dict())
                run_state.optimizer_state_best_hard = copy.deepcopy(optimizer.state_dict())

            # NOTE: nested in the logging branch, so a checkpoint is written only
            # when `step` is a multiple of both log_freq and save_freq.
            if step % args.save_freq == 0 and args.save_model:
                _record_last_state(run_state, step, optimizer, data_model, query_model, extra_model)
                save_run_state(run_state)

    _record_last_state(run_state, step, optimizer, data_model, query_model, extra_model)


def save_run_state(run_state):
    """Write ``run_state`` with ``torch.save`` to ``out.pk`` inside ``run_state.args.out_dir``."""
    destination = run_state.args.out_dir / 'out.pk'
    torch.save(run_state, destination)

def create_optimizer(args: Args, query_model, data_model, extra_model):
    """Create one AdamW optimizer with a separate learning rate per model.

    The query model always gets a parameter group; the data model and the
    extra model each get one only when they are not ``None``. All groups share
    ``args.weight_decay``.
    """
    optional_groups = (
        (data_model, args.datastructure_model_lr),
        (extra_model, args.extra_model_lr),
    )
    param_groups = [{"params": query_model.parameters(), "lr": args.query_model_lr}]
    param_groups.extend(
        {"params": model.parameters(), "lr": lr}
        for model, lr in optional_groups
        if model is not None
    )
    return torch.optim.AdamW(param_groups, weight_decay=args.weight_decay)

def main():
    """Parse the configuration, build models and optimizer, train, and save the run.

    Raises:
        ValueError: if ``n_extra`` is non-zero while ``use_data_model`` is off
            (extra entries are only produced on top of a data model).
    """
    args: Args = parse_args(Args)

    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if not args.use_data_model and args.n_extra != 0:
        raise ValueError(
            f"n_extra must be 0 when use_data_model is False (got n_extra={args.n_extra})"
        )

    run_state = RunState(args=args)
    print(args)
    set_seed(args.seed)

    if args.save_model:
        args.out_dir.mkdir(parents=True, exist_ok=True)

    # Model initialization
    data_model = None
    extra_model = None
    query_model = MLPQueryModel(args).to(args.device)
    if args.use_data_model:
        data_model = DatastructureModel(
            args.n_inputs,
            args.n_inputs,
            args.dim,
            extra_space_dim=args.extra_space_dim,
            permute=args.permute,
            n_layer=args.datamodel_n_layer,
            n_embd=args.datamodel_n_embd,
            n_head=args.datamodel_n_head,
            output_dim=(1 if args.permute else args.dim),
            arch=args.datamodel_arch,
        ).to(args.device)

        if args.n_extra > 0:
            extra_model = DatastructureModel(
                args.n_inputs,
                args.n_extra,
                args.dim,
                extra_space_dim=args.extra_space_dim,
                permute=False,
                n_layer=args.datamodel_n_layer,
                n_embd=args.datamodel_n_embd,
                n_head=args.datamodel_n_head,
                output_dim=args.extra_space_dim,
                arch=args.datamodel_arch,
            ).to(args.device)

    optimizer = create_optimizer(args, query_model, data_model, extra_model)

    print(data_model)
    print(extra_model)
    print(query_model)

    # Distributions backed by a dataset file on disk; anything else is
    # generated on the fly (saved_data=None).
    file_backed_distributions = (
        'glove', 'sift', 'glovev2', 'glovev3', 'lastfm',
        'glove_norm', 'fashion_mnist', 'sift_no_norm',
    )
    if args.input_distribution_type in file_backed_distributions:
        # NOTE(security): torch.load unpickles the file -- only load trusted data.
        saved_datasets = torch.load(f'ann_data/{args.input_distribution_type}.pk')
        # Keep only the first `dim` columns of each split.
        saved_datasets['train'] = saved_datasets['train'][:, :args.dim]
        saved_datasets['test'] = saved_datasets['test'][:, :args.dim]
    else:
        saved_datasets = {
            'train': None,
            'test': None,
        }

    train_model(
        run_state,
        args,
        optimizer,
        data_model,
        query_model,
        extra_model,
        saved_datasets,
    )

    if args.save_model:
        # train_model only tracks the best *hard* eval loss; `loss_best` is
        # never updated and would always print inf.
        print("Best loss:", run_state.loss_best_hard)
        print(f'Saving model to:\n{args.out_dir/"out.pk"}')
        save_run_state(run_state)

if __name__ == '__main__':
    # Script entry point: run main() and report the wall-clock duration.
    launched_at = datetime.now()
    main()
    elapsed = datetime.now() - launched_at
    print('Total time:', elapsed)
