import numpy as np
import argparse
import torch
import torch.nn as nn
from torch.optim import Adam
import os
import random

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.set_cmap('Greys_r')

from model import Model
from data import dataset

# Command-line interface. Options are registered in a fixed order, which is
# also the order in which they appear in `--help`.
parser = argparse.ArgumentParser(description="Toy Task for Transformer")

# Sizes forwarded to Model as dim / search_dim / value_dim / search / retrieve.
parser.add_argument("--dim", type=int, default=64)
parser.add_argument("--search-dim", type=int, default=64)
parser.add_argument("--value-dim", type=int, default=64)
parser.add_argument("--search", type=int, default=2)
parser.add_argument("--retrieve", type=int, default=2)

# Batch size and sequence length handed to dataset(); --iterations and --lr
# are accepted but not referenced in the visible part of this script.
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--seq-len", type=int, default=10)
parser.add_argument("--iterations", type=int, default=100000)
parser.add_argument("--lr", type=float, default=0.0001)

# Boolean switches (store_true already defaults to False).
parser.add_argument("--nonlinear", action="store_true")
parser.add_argument("--concat", action="store_true")
parser.add_argument("--no-bias", action="store_true")

# --seed is part of the checkpoint directory path.
parser.add_argument("--seed", type=int, default=1)
parser.add_argument(
    "--model",
    type=str,
    default="Standard",
    choices=("Standard", "Compositional", "Compositional-dot"),
)

# v_s and v_p determine the model input width and are passed to dataset().
parser.add_argument("--v-p", type=int, default=2)
parser.add_argument("--v-s", type=int, default=2)
parser.add_argument("--gumbel", action="store_true")
parser.add_argument("--separate", action="store_true")

args = parser.parse_args()

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN flags.

    Args:
        seed: integer seed applied to every generator.
    """
    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # CUDA generators (current device, then all devices).
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ask cuDNN for deterministic kernels and disable its autotuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# NOTE(review): the RNGs are seeded with the literal 100, not with --seed.
# In this script --seed only selects the checkpoint directory below; confirm
# that a fixed evaluation seed is intended.
set_seed(100)

# Checkpoint directory layout:
#   Trained_Models_{Nonlinear|Linear}/{seq_len}_{v_s}_{v_p}/{seed}/
#     {model}_{dim}_{search_dim}_{value_dim}_{search}_{retrieve}[_gumbel][_concat][_separate]
name = 'Trained_Models_Nonlinear' if args.nonlinear else 'Trained_Models_Linear'
name = f'{name}/{args.seq_len}_{args.v_s}_{args.v_p}/{args.seed}'
name = f'{name}/{args.model}_{args.dim}_{args.search_dim}_{args.value_dim}_{args.search}_{args.retrieve}'

if args.gumbel:
    name = f'{name}_gumbel'

if args.concat:
    name = f'{name}_concat'

# The "_separate" suffix is only used for the Compositional-dot model.
if args.separate and args.model == 'Compositional-dot':
    name = f'{name}_separate'

if not os.path.exists(name):
    print("Model does not exist")
    # `exit()` is injected by the `site` module for interactive use and is
    # absent under `python -S`; raising SystemExit directly ends the script
    # with the same (zero) exit status.
    raise SystemExit

# Input width handed to the model: v_s + v_p + v_s * v_p.
in_dim = args.v_s + args.v_p + args.v_s * args.v_p

# Build the network from the command-line configuration, then move it to the GPU.
model = Model(
    in_dim=in_dim, dim=args.dim,
    search_dim=args.search_dim, value_dim=args.value_dim,
    model=args.model,
    search=args.search, retrieve=args.retrieve,
    nonlinear=args.nonlinear, bias=not args.no_bias,
    gumbel=args.gumbel, concat=args.concat, separate=args.separate,
)
model = model.cuda()

# Total number of scalar entries across all parameter tensors.
num_params = sum(map(torch.numel, model.parameters()))
print(f"Number of Parameters: {num_params}")

# Restore the trained weights saved as model.pt inside the checkpoint directory.
model.load_state_dict(torch.load(f'{name}/model.pt'))

# Loss used by eval_step.
criterion = nn.L1Loss()

def _save_heatmap(image, filename):
    """Draw `image` on a fixed [0, 1] colour scale and write it to `name`/`filename`."""
    plt.imshow(image, vmin=0., vmax=1.)
    plt.colorbar()
    plt.tight_layout()
    plt.savefig(os.path.join(name, filename))
    plt.close()


def save(iteration, data, search, score, f_score=None):
    """Write diagnostic heatmaps for one example into the checkpoint directory.

    Every file name is prefixed with ``iteration_{iteration}_``:

    * ``gt_search.png`` -- the ``search`` array.
    * For the 'Compositional' and 'Compositional-dot' models:
      ``search_{i}.png`` (``score[i]``) and ``value_search_{i}.png``
      (``f_score[:, i, :, 0]``) for each ``i < args.search``, plus
      ``activation.png`` (flattened ``f_score`` next to the trailing
      ``v_s * v_p`` columns of ``data``).
    * For any other model: ``head_{i}.png`` (``score[i]``).
    * ``task.png`` -- the trailing ``v_s * v_p`` columns of ``data``.

    Args:
        iteration: value interpolated into the file names.
        data: array sliced as ``data[:, -k:]``.
        search: array drawn as an image.
        score: tensor; converted to NumPy and indexed along its first axis.
        f_score: optional tensor. It is used unconditionally in the
            compositional branch, so it must be given for those models.
    """
    score = score.detach().cpu().numpy()
    if f_score is not None:
        # Flattened view with one column per (search, retrieve) pair.
        v_score = f_score.view(-1, args.search * args.retrieve).detach().cpu().numpy()
        f_score = f_score.detach().cpu().numpy()

    prefix = f'iteration_{iteration}'
    n_task_cols = args.v_s * args.v_p

    _save_heatmap(search, f'{prefix}_gt_search.png')

    if args.model in ('Compositional', 'Compositional-dot'):
        for head in range(args.search):
            _save_heatmap(score[head].squeeze(), f'{prefix}_search_{head}.png')
            _save_heatmap(f_score[:, head, :, 0], f'{prefix}_value_search_{head}.png')

        # Combined activation plot: model activations stacked with the
        # trailing task columns, transposed so each becomes a row.
        stacked = np.concatenate([v_score[:, :], data[:, -n_task_cols:]], axis=1)
        plt.imshow(stacked.T, vmin=0., vmax=1.)

        labels = []
        for gt_idx in range(1, args.v_s + 1):
            labels.append(f'Ground Truth Search {gt_idx}')
            labels.extend([''] * (args.v_p - 1))
        for search_idx in range(1, args.search + 1):
            labels.append(f'Search {search_idx} | Value 1')
            labels.extend(f'Value {value_idx}' for value_idx in range(2, args.retrieve + 1))

        # Tick positions count down from the last row to row 0.
        tick_positions = np.arange(args.search * args.retrieve + n_task_cols - 1, -1, -1)
        plt.yticks(ticks=tick_positions, labels=labels, rotation=45)
        plt.colorbar()
        plt.tight_layout()
        plt.savefig(os.path.join(name, f'{prefix}_activation.png'))
        plt.close()
    else:
        for head in range(args.search):
            _save_heatmap(score[head], f'{prefix}_head_{head}.png')

    _save_heatmap(data[:, -n_task_cols:], f'{prefix}_task.png')

def eval_step(seq_len, num_batches=250):
    """Return the average `criterion` loss over freshly generated batches.

    The model is switched to eval mode and left in that mode on return.

    Args:
        seq_len: sequence length passed to `dataset` for every batch.
        num_batches: number of batches to average over (default 250).

    Returns:
        Mean of the per-batch loss values as a Python float.

    Raises:
        ZeroDivisionError: if `num_batches` is 0.
    """
    model.eval()
    total_loss = 0.

    # Evaluation never calls backward(), so skip autograd bookkeeping to
    # avoid building a graph for every forward pass.
    with torch.no_grad():
        for _ in range(num_batches):
            data, label, _ = dataset(args.batch_size, seq_len, args.v_s, args.v_p)

            data = torch.Tensor(data).cuda()
            label = torch.Tensor(label).cuda().view(-1)

            # Only the prediction is needed; the other two outputs are ignored.
            pred, _, _ = model(data)

            # Flatten predictions to line up with the flattened labels.
            loss = criterion(pred.view(-1), label)
            total_loss += loss.item()

    return total_loss / num_batches

def plot_step(iteration):
    """Run the model on one fresh batch and save plots for a random example.

    Args:
        iteration: tag forwarded to `save` for use in the output file names.
    """
    # Pick the example index before generating the batch so the order of
    # NumPy RNG calls is unchanged (dataset() may draw from the same RNG).
    i = np.random.choice(args.batch_size)

    d, _, search = dataset(args.batch_size, args.seq_len, args.v_s, args.v_p)

    data = torch.Tensor(d).cuda()

    # Plotting only needs the values, so do not build an autograd graph.
    with torch.no_grad():
        _, score, f_score = model(data)

    save(iteration, d[i], search[i], score[i],
         f_score[i] if f_score is not None else None)

# Report the evaluation loss for sequence lengths 10, 20, 30, 40 and 50.
for s in range(10, 51, 10):
    eval_loss = eval_step(s)
    print(f'Sequence Length: {s} | Eval Loss: {eval_loss:.3f}')