from PIL import Image
import numpy as np
import torch
import h5py
import tqdm 
import numpy as np
import torch.nn as nn

from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt 
import robomimic.utils.file_utils as FileUtils
import argparse 
import os 
from torch.utils.tensorboard import SummaryWriter
import time 

from embedder_models import FinalStatePredictionDino
from embedder_datasets import MultiviewDataset
from image_models import VAE

import torchvision 
import shutil 
import json 
import random 

# Module-level mean-squared-error criterion (torch.nn.MSELoss with its default settings).
mse_loss = torch.nn.MSELoss()

def sigmoid(z):
    """Logistic function 1 / (1 + e^(-z)), computed with NumPy so scalars and arrays both work."""
    denominator = 1 + np.exp(-z)
    return 1 / denominator

def kld_prior(mean, log_var):
    """Closed-form KL(N(mean, exp(log_var)) || N(0, 1)), summed over dim 1 and averaged over dim 0.

    Args:
        mean: tensor of Gaussian means; dim 1 is reduced by summation, dim 0 by averaging.
        log_var: tensor of log-variances with the same shape as ``mean``.
    """
    per_element = 1 + log_var - mean ** 2 - torch.exp(log_var)
    per_sample = -0.5 * per_element.sum(dim = 1)
    return per_sample.mean(dim = 0)

def make_filmstrip(current_state, reco_states, true_states, save_dir):
    """Save a comparison image with columns: current state | reconstruction | ground truth.

    Each argument is a 4-D image batch (axis 1 is moved to the end, so presumably N, C, H, W —
    TODO confirm). Every batch is stacked top-to-bottom into one column, and the three columns
    are placed left to right. Reconstructions are first resized to a square whose side is
    ``current_state.shape[-1]`` and clipped to [0, 1]. The result is written to ``save_dir``
    with ``plt.imsave``.
    """
    side = current_state.shape[-1]
    # Clipping makes the image legal for imsave; a well-trained model should need little of it.
    reco = torch.clip(torchvision.transforms.Resize((side, side))(reco_states), 0, 1)

    def to_column(batch):
        # Move axis 1 last, then stack the samples of the batch vertically.
        moved = np.transpose(batch.detach().cpu().numpy(), (0, 2, 3, 1))
        return np.concatenate(list(moved), axis = 0)

    columns = [to_column(current_state), to_column(reco), to_column(true_states)]
    plt.imsave(save_dir, np.concatenate(columns, axis = 1))


def get_valid_stats(model, sampler, generator, exp_dir, step, camera = "robot0_eye_in_hand_image", num_batches = 50):
    """Evaluate ``model`` on ``num_batches`` validation batches and summarize the losses.

    Args:
        model: called as ``model(state, action) -> (z_hat_last, reco_last)``; it must also expose
            ``state_embedding(last_state)``.
        sampler: the validation DataLoader. It is re-iterated when ``generator`` is exhausted.
        generator: a live iterator over ``sampler``. It is returned (possibly refreshed) so the
            caller can keep drawing batches across calls.
        exp_dir: directory that receives filmstrip PNGs.
        step: current epoch. Filmstrips are written only when ``step % 40 == 0``.
        camera: observation key used for the reconstruction loss and the filmstrips.
        num_batches: number of batches to average over. The default of 50 matches the previous
            hard-coded value.

    Returns:
        A tuple ``(info, stds, (z_hat_mean, z_hat_logvar), generator)``:
        - ``info`` holds the per-batch averages of "overall", "mse_loss" and "reco_loss", plus
          "min_var" and "max_var". Despite their names, these two are the min and max of ``stds``.
        - ``stds`` is the per-dimension standard deviation of the predicted embeddings over all
          evaluated samples.
        - ``(z_hat_mean, z_hat_logvar)`` are placeholder zeros.

    Raises:
        ValueError: if ``num_batches`` is less than 1. With no batches there would be nothing to
            average.
    """
    if num_batches < 1:
        raise ValueError(f"num_batches must be at least 1, got {num_batches}")

    mse_loss = torch.nn.MSELoss()
    info = {"overall" : 0, "mse_loss" : 0, "reco_loss" : 0}
    embedding_list = list()
    resizer = None
    loss_count = 0
    for _ in tqdm.tqdm(range(num_batches)):
        try:
            sample = next(generator)
        except StopIteration:
            # A fixed number of batches is drawn, so restart the loader when it runs dry.
            print("wrapping around!")
            generator = iter(sampler)
            sample = next(generator)

        state, action, last_state = prepare(sample[0]), prepare(sample[1]), prepare(sample[2])
        with torch.no_grad():
            z_hat_last, reco_last = model(state, action)
            embedding_list.append(z_hat_last)
            if resizer is None:
                side = last_state[camera].shape[-1]
                resizer = torchvision.transforms.Resize((side, side))

            # Inputs are 0 to 255, but reconstructions are 0-1.
            reco_loss = mse_loss(resizer(reco_last), last_state[camera] / 255)
            info["reco_loss"] += reco_loss.item()

            z_last = model.state_embedding(last_state)
            mse_loss_value = mse_loss(z_last, z_hat_last)
            info["mse_loss"] += mse_loss_value.item()

            loss = mse_loss_value + reco_loss
            info["overall"] += loss.item()

        # Render every 50th batch (only the first one with the default num_batches), on every 40th step.
        if loss_count % 50 == 0 and step % 40 == 0:
            print("Making filmstrip!")
            make_filmstrip(state[camera] / 255, reco_last, last_state[camera] / 255, exp_dir + f"/rc_{step}_{loss_count}.png")

        loss_count += 1

    # Per-dimension std of the predicted embeddings across every evaluated sample. The "min_var" and
    # "max_var" keys are kept for log continuity but hold standard deviations, not variances.
    cat_embed = torch.cat(embedding_list, dim = 0)
    stds = torch.std(cat_embed, dim = 0).detach().cpu().numpy()
    info = {k : v / loss_count for k, v in info.items()}
    print(f"Average Validation Losses: {info}")
    info["min_var"] = np.min(stds)
    info["max_var"] = np.max(stds)
    print(f"Min batch variance is: {np.min(stds)} and maximum is {np.max(stds)}")
    # Placeholder zeros kept for the caller's histogram logging; no mean/log-variance is computed here.
    z_hat_mean = z_hat_logvar = 0

    return info, stds, (z_hat_mean, z_hat_logvar), generator

def save_scalar_stats(writer, info_dict, epoch, mode = "train"):
    """Log each entry of ``info_dict`` via ``writer.add_scalar`` under the tag "<mode>/<key>" at step ``epoch``."""
    tagged = ((f"{mode}/{name}", val) for name, val in info_dict.items())
    for tag, scalar in tagged:
        writer.add_scalar(tag, scalar, epoch)

def prepare(data, device = "cuda"):
    """Move a tensor, or a dict of tensors, to ``device`` and cast to float32.

    Args:
        data: None, a tensor, or a dict (including dict subclasses such as OrderedDict) whose
            values are tensors.
        device: target device.

    Returns:
        None if ``data`` is None, which lets optional batch entries pass through. For a dict, a new
        plain dict with every value converted. Otherwise, the converted tensor.
    """
    if data is None:
        return None  # passthrough for optional inputs

    # isinstance (not ``type(...) == dict``) so dict subclasses are converted too, instead of
    # falling through to ``.to`` and raising AttributeError.
    if isinstance(data, dict):
        return {k: v.to(device=device, dtype=torch.float32) for k, v in data.items()}
    return data.to(device=device, dtype=torch.float32)

def loss_scheduler(step, weight, mode = "linear", start = 0, end = 2000):
    """Return the weight of a loss term at training step ``step``, ramping from 0 toward ``weight``.

    Args:
        step: current training step.
        weight: target (maximum) weight.
        mode: the shape of the ramp.
            - "linear": ramps linearly from 0 at ``start`` to ``weight`` at ``end``, clamped
              between 0 and ``weight``.
            - "sigmoid": 0 before ``start``; afterwards it follows ``weight * sigmoid(-4 + 8 * progress)``.
              That is about 1.8% of ``weight`` at ``start`` and about 98.2% at ``end``, and it keeps
              approaching ``weight`` after ``end``.
        start, end: step range of the ramp. If they are equal, the schedule is a hard switch: 0
            before ``start`` and ``weight`` from ``start`` on. This avoids dividing by zero.

    Raises:
        ValueError: if ``mode`` is not "linear" or "sigmoid". The previous behavior was to
            silently return None.
    """
    if mode not in ("linear", "sigmoid"):
        raise ValueError(f"Unknown loss schedule mode: {mode!r} (expected 'linear' or 'sigmoid')")

    if end == start:
        # Zero-length ramp: switch the term on at `start`.
        return weight if step >= start else 0

    progress = (step - start) / (end - start)
    if mode == "linear":
        return min(max(0, progress * weight), weight)

    # sigmoid mode: maps progress in [0, 1] onto the logistic curve over [-4, 4].
    if step < start:
        return 0
    return weight * sigmoid(-4 + 8 * max(0, progress))


def _make_loader(dataset, batch_size):
    """Wrap ``dataset`` in the shuffled, 4-worker DataLoader configuration shared by both splits."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, sampler=None,
            batch_sampler=None, num_workers=4, collate_fn=None,
            pin_memory=False, drop_last=False, timeout=0,
            worker_init_fn=None, # prefetch_factor=2,
            persistent_workers=False)


def _snapshot_experiment(args):
    """Create ``args.exp_dir`` and copy the training sources and CLI arguments into it."""
    if args.exp_dir is None:
        raise ValueError("--exp_dir is required (it receives logs, checkpoints and source snapshots)")
    # makedirs also creates missing parent directories and tolerates an existing directory.
    os.makedirs(args.exp_dir, exist_ok=True)
    for source in ("train_end_state_embedder_dino.py", "embedder_models.py", "embedder_datasets.py"):
        # Source paths are relative to the current working directory.
        shutil.copy(source, os.path.join(args.exp_dir, source))
    with open(os.path.join(args.exp_dir, "args.json"), "w") as f:
        json.dump(vars(args), f)


def _noise_action_batch(action, noise_scheduler, chance_of_mask = 0.5):
    """Replace each sample's action chunk with a diffusion-noised copy, with probability ``chance_of_mask``.

    Modifies ``action`` in place and also returns it.
    """
    # torch's Geometric(p) counts failures before the first success (mean 1/p - 1 = 19 here).
    # The timesteps are clipped to the scheduler's 100 training timesteps.
    m = torch.distributions.geometric.Geometric(0.05 * torch.ones(action.shape[0]))
    timesteps = torch.clip(m.sample(), 0, 99).long()
    noise = torch.randn(action.shape, device=action.device)
    noised_action = noise_scheduler.add_noise(action, noise, timesteps)
    # Build the mask on action's device so boolean indexing never mixes devices.
    mask = torch.rand(action.shape[0], device=action.device) < chance_of_mask
    action[mask] = noised_action[mask]
    return action


def main(args):
    """Train FinalStatePredictionDino on ``args.train_hdf5`` and validate on ``args.test_hdf5``.

    Each epoch runs 100 optimizer steps. The loss is the MSE between the predicted final-state
    embedding and ``model.state_embedding(last_state)``, plus the pixel MSE of the resized
    reconstruction against the true final image scaled to 0-1. Validation runs every 5 epochs,
    and a checkpoint ``<exp_dir>/<epoch>.pth`` is written every 100 epochs.
    """
    ACTION_DIM = 7  #7 #2 #7
    CAMERA = "third_person" # "image" for the cube environment

    cameras = [CAMERA] # you can change this
    padding = True
    pad_mode = "repeat" # "zeros" if you're in the delta environment

    proprio_dim = 15
    proprio = "proprio" # set to None to exclude proprioception (applies to training and validation)

    _snapshot_experiment(args)

    model = FinalStatePredictionDino(ACTION_DIM, args.action_chunk_length, cameras=cameras, reconstruction = True, proprio = proprio, proprio_dim = proprio_dim)
    model.to("cuda")
    print(model.trainable_parameters())

    dataset = MultiviewDataset(args.train_hdf5, action_chunk_length = args.action_chunk_length, cameras = cameras, proprio = proprio,
        padding = padding, pad_mode = pad_mode)
    sampler = _make_loader(dataset, args.batch_size)
    sample_generator = iter(sampler)

    # Use the same `proprio` setting as the model and training data (it was previously hard-coded here).
    valid_dataset = MultiviewDataset(args.test_hdf5, action_chunk_length = args.action_chunk_length, cameras = cameras, proprio = proprio,
        padding = padding, pad_mode = pad_mode)
    valid_sampler = _make_loader(valid_dataset, args.batch_size)
    valid_generator = iter(valid_sampler)

    noise_scheduler = None
    if args.noised:
        print("I am augmenting the transformer with diffusion action noise!")
        from diffusers.schedulers.scheduling_ddim import DDIMScheduler
        noise_scheduler = DDIMScheduler(
                num_train_timesteps=100,
                beta_schedule="squaredcos_cap_v2",
                clip_sample=True,
                prediction_type="epsilon",
                steps_offset=0,
                set_alpha_to_one=True
        )

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    writer = SummaryWriter(args.exp_dir) #you can specify logging directory
    mse_loss = torch.nn.MSELoss()
    resizer = None

    try:
        for i in range(args.num_epochs):
            info = {"overall" : 0, "mse_loss" : 0, "reco_loss" : 0} # reset per epoch
            loss_count = 0
            print(f"----------------Training Step {i}------------------")

            for _ in tqdm.tqdm(range(100)):
                try:
                    sample = next(sample_generator)
                except StopIteration:
                    # Epochs are a fixed 100 batches, so restart the loader when it runs dry.
                    sample_generator = iter(sampler)
                    sample = next(sample_generator)

                state, action, last_state = prepare(sample[0]), prepare(sample[1]), prepare(sample[2])
                if noise_scheduler is not None:
                    action = _noise_action_batch(action, noise_scheduler)

                z_hat_last, reco_last = model(state, action)
                if resizer is None:
                    side = last_state[CAMERA].shape[-1]
                    resizer = torchvision.transforms.Resize((side, side))

                # Inputs are 0 to 255, but reconstructions are 0-1.
                reco_loss = mse_loss(resizer(reco_last), last_state[CAMERA] / 255)
                info["reco_loss"] += reco_loss.item()

                z_last = model.state_embedding(last_state)
                mse_loss_value = mse_loss(z_last, z_hat_last)
                info["mse_loss"] += mse_loss_value.item()

                loss = mse_loss_value + reco_loss
                info["overall"] += loss.item()

                optimizer.zero_grad() # gradients accumulate, so reset them every step
                loss.backward()
                optimizer.step()

                loss_count += 1

            info = {k : v / loss_count for k, v in info.items()}
            save_scalar_stats(writer, info, i, "train")
            print("Average training losses: ", info)

            if i % 5 == 0: # evaluate sparsely to keep epochs cheap
                print("--------------Evaluating-----------------")
                model.eval()
                stats, embeddings_std, (mean, logvar), valid_generator = get_valid_stats(model, valid_sampler, valid_generator, args.exp_dir, step = i, camera = CAMERA)
                model.train()
                save_scalar_stats(writer, stats, i, "valid")
                writer.add_histogram("valid/embeddings_std", embeddings_std, i)
                writer.add_histogram("valid/mean", mean, i)
                writer.add_histogram("valid/logvar", logvar, i)
            if i % 100 == 0:
                # Join with a separator so the checkpoint lands inside exp_dir; plain concatenation
                # produced "<exp_dir>0.pth" next to it.
                torch.save(model.state_dict(), os.path.join(args.exp_dir, f"{i}.pth"))
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()


        
if __name__ == "__main__":
    torch.set_printoptions(sci_mode = False)
    parser = argparse.ArgumentParser(description="Train the DINO-based final-state embedder.")
    # main() concatenates exp_dir into paths and loops over range(num_epochs), so both must be set.
    parser.add_argument(
        "--exp_dir",
        type=str,
        required=True,
        help="directory for TensorBoard logs, checkpoints, filmstrips and source snapshots",
    )
    parser.add_argument(
        "--train_hdf5",
        type=str,
        default=None,
        help="path to the training HDF5 file (passed to MultiviewDataset)",
    )

    # NOTE(review): args.unlock is not read by main() in this file.
    parser.add_argument(
        "--unlock",
        action='store_true',
        help="unlocks the encoder at a set time",
    )

    parser.add_argument(
        "--noised",
        action='store_true',
        help="replace each training sample's action chunk with a diffusion-noised copy with probability 0.5",
    )

    parser.add_argument(
        "--test_hdf5",
        type=str,
        default=None,
        help="path to the validation HDF5 file (passed to MultiviewDataset)",
    )

    parser.add_argument(
        "--num_epochs",
        type=int,
        required=True,
        help="number of training epochs (100 optimizer steps each)",
    )
    parser.add_argument(
        "--action_chunk_length",
        type=int,
        default=None,
        help="action chunk length passed to the model and datasets",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="batch size for the training and validation loaders",
    )
    args = parser.parse_args()

    main(args)