import numpy as np
import os
from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from PickleDataset import PickleDataset, transform_noise, transform_set_noise, transform_finite_noise, transform_mask, transform_discrete_noise
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from cache_datasets import cache_data
from models.GTDM_Model import GTDM_Controller, Conv_GTDM_Controller

import argparse
import random
import configargparse
import sys

class Tee:
    """File-like object that fans every write out to several streams.

    Intended as a temporary replacement for ``sys.stdout`` so that printed
    text lands both on the console and in a log file. Each stream is flushed
    right after it receives a message, before the next stream is written.
    """

    def __init__(self, *file_objects):
        # Target streams, kept in the order they were passed in.
        self.file_objects = file_objects

    def write(self, message):
        """Write ``message`` to every stream, flushing each one immediately."""
        for stream in self.file_objects:
            stream.write(message)
            stream.flush()  # push output out right away (e.g. to the log file)

    def flush(self):
        """Flush all wrapped streams."""
        for stream in self.file_objects:
            stream.flush()


class CosineAnnealer:
    """Quarter-cosine schedule decaying from ``max`` (step 0) to ``min``.

    ``forward(step)`` returns
    ``(max - min) * cos(pi / (2 * num_epochs) * step) + min``.
    Steps beyond ``num_epochs`` are clamped to ``num_epochs``. Once the
    schedule has finished, the value therefore stays at ``min`` (up to
    floating-point error).

    Args:
        num_epochs: number of steps over which the value decays.
        max: value returned at step 0.
        min: value reached at ``step >= num_epochs``.
    """

    def __init__(self, num_epochs=100, max=5, min=1):
        self.num_epochs = num_epochs
        self.max = max
        self.min = min

    def forward(self, step):
        """Return the scheduled value for ``step`` (clamped to ``num_epochs``)."""
        if step > self.num_epochs:
            step = self.num_epochs
        # Expressed as amplitude * cos + min. This avoids dividing by the
        # amplitude, so a flat schedule (max == min) is well defined instead
        # of raising ZeroDivisionError.
        amplitude = self.max - self.min
        return amplitude * np.cos(np.pi / (2 * self.num_epochs) * step) + self.min

def mseloss(t1, t2):
    """Return the Euclidean (L2) distance between two sequences of scalars.

    Despite its name, this is not a mean squared error. It is the square root
    of the summed squared element-wise differences. Every element is read
    with ``.item()``, so the result is a plain Python number that carries no
    autograd graph. ``t2`` is indexed at every position of ``t1``.
    """
    squared_total = sum(
        (t1[idx].item() - t2[idx].item()) ** 2 for idx in range(len(t1))
    )
    return squared_total ** 0.5

def _str2bool(value):
    """Parse a boolean option value.

    ``type=bool`` would turn any non-empty string (including ``"False"``)
    into True, so textual values are interpreted explicitly here.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _split_files(base_root, split):
    """Return the HDF5 file paths of one dataset split (``train``/``val``/``test``).

    Order: mocap, then mmwave, realsense, respeaker and zed, each for
    node_1..node_3.
    """
    data_root = base_root + '/' + split
    files = [f'{data_root}/mocap.hdf5']
    for sensor in ('mmwave', 'realsense', 'respeaker', 'zed'):
        files.extend(f'{data_root}/node_{node}/{sensor}.hdf5' for node in (1, 2, 3))
    return files


def get_args_parser():
    """Parse training options from ``./configs/configs.yaml`` and the command line.

    Command-line values override the config file. Besides the declared
    options, the returned namespace gets ``trainset``, ``valset`` and
    ``testset``: lists of HDF5 paths under ``<base_root>/train``, ``/val``
    and ``/test``.

    Returns:
        argparse.Namespace with all parsed options plus the dataset file lists.
    """
    parser = configargparse.ArgumentParser(description='GTDM Controller Training, load config file and override params',
                                           default_config_files=['./configs/configs.yaml'], config_file_parser_class=configargparse.YAMLConfigFileParser)
    # Define the parameters with their default values and types
    parser.add("--base_root", type=str, help="Base directory for datasets")
    parser.add("--cache_dir", type=str, help="Directory to cache datasets")
    parser.add("--valid_mods", type=str, nargs="+", help="List of valid modalities")
    parser.add("--valid_nodes", type=int, nargs="+", help="List of valid nodes")
    parser.add("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
    parser.add("--num_epochs", type=int, default=10, help="Number of epochs to train")
    parser.add("--adapter_hidden_dim", type=int, default=512, help="Dimension of adapter hidden layers")
    parser.add("--batch_size", type=int, default=32, help="Batch size for training")
    # _str2bool instead of bool so that "False"/"false" actually disable it.
    parser.add("--save_best_model", type=_str2bool, default=True, help="Save the best model")
    parser.add("--save_every_X_model", type=int, default=5, help="Save model every X epochs")
    parser.add('--total_layers', type=int, default=8, help="How many layers to reduce to")
    parser.add('--seedVal', type=int, default=100, help="Seed for training")
    parser.add('--train_type', type=str, default='continuous', choices=['continuous', 'discrete', 'finite'])

    # Parse arguments from the configuration file and command-line
    args = parser.parse_args()
    if args.base_root is None:
        # Fail with a usage message rather than a TypeError on string concatenation.
        parser.error("--base_root must be set (in the config file or on the command line)")

    args.trainset = _split_files(args.base_root, 'train')
    args.valset = _split_files(args.base_root, 'val')
    args.testset = _split_files(args.base_root, 'test')

    return args



def main(args):
    """Fine-tune the controller of a pre-trained ``Conv_GTDM_Controller``.

    Steps:
    - Seed all RNGs from ``args.seedVal`` and create the run directory
      ``./logs/Controller_<train_type>_Layer_<total_layers>_Seed_<seedVal>``.
    - Cache the datasets and load ``./logs/Correct_Noisy_3_Mod/last.pt``
      non-strictly.
    - Freeze everything except ``model.controller``.
    - Train for ``args.num_epochs`` epochs. During the first epoch only the
      noise-prediction loss is optimized.
    - After every epoch, validate, append to ``log.txt`` and overwrite
      ``last.pt`` in the run directory.

    Args:
        args: namespace returned by ``get_args_parser``.

    Raises:
        FileExistsError: if the run directory already exists (``os.mkdir``).
        Exception: if ``args.train_type`` is not 'continuous', 'finite' or 'discrete'.
    """
    print("Starting training with seed value", args.seedVal)
    torch.backends.cudnn.deterministic = True
    random.seed(args.seedVal)
    torch.manual_seed(args.seedVal)
    torch.cuda.manual_seed(args.seedVal)
    np.random.seed(args.seedVal)

    # The run directory is named from the configuration, so re-running the
    # same configuration fails in os.mkdir instead of overwriting results.
    dt_string = "Controller_" + str(args.train_type) + '_Layer_' + str(args.total_layers) + '_Seed_' + str(args.seedVal)
    run_dir = './logs/' + dt_string
    os.mkdir(run_dir)
    cache_data(args)  # converts hdf5 to pickle if not already cached

    # PickleDataset is a PyTorch Dataset over the cached train/val pickles.
    trainset = PickleDataset(args.cache_dir + 'train', args.valid_mods, args.valid_nodes)
    valset = PickleDataset(args.cache_dir + 'val', args.valid_mods, args.valid_nodes)
    batch_size = args.batch_size

    train_dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=20)
    val_dataloader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=20)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

    # Build the model and warm-start it from a previous run (non-strict load).
    model = Conv_GTDM_Controller(args.adapter_hidden_dim, valid_mods=args.valid_mods, valid_nodes=args.valid_nodes, total_layers=args.total_layers)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    print(model.load_state_dict(torch.load('./logs/Correct_Noisy_3_Mod/last.pt', map_location=device), strict=False))
    model.to(device)

    # Freeze the whole model, then unfreeze only the controller.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.controller.parameters():
        param.requires_grad = True

    # Two parameter groups (controller body / output head) so their learning
    # rates can be tuned separately; currently both use args.learning_rate.
    params = [
        {"params": [p for name, p in model.controller.named_parameters() if "output_head" not in name], "lr": args.learning_rate},
        {"params": model.controller.output_head.parameters(), "lr": args.learning_rate},
    ]
    optimizer = Adam(params)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1, end_factor=0.01, total_iters=args.num_epochs)
    annealer = CosineAnnealer(25, 2, 1)

    writer = SummaryWriter(log_dir=run_dir)  # tensorboard logging

    for epoch in trange(args.num_epochs, desc="Training"):
        batch_num = 0
        epoch_train_loss = 0
        ad_train_loss = 0
        model.train()
        for batch in train_dataloader:
            batch_num += 1

            train_loss = 0.0
            # Each batch is a dict with the sensor data and ground-truth positions;
            # data is itself a dict keyed by ('modality', 'node').
            data, gt_pos = batch['data'], batch['gt_pos']
            gt_pos = gt_pos.to(device)
            # NOTE(review): args.batch_size is passed even for a smaller final
            # batch (drop_last is False) — confirm the transforms tolerate that.
            if args.train_type == 'continuous':
                data, gt_noise = transform_noise(data, args.batch_size, img_std_max=4, depth_std_max=0.75, mmWave_std_max=0.8)
            elif args.train_type == 'finite':
                data, gt_noise = transform_finite_noise(data, args.batch_size, img_std_max=3, depth_std_max=0.75, mmWave_std_max=0.8)
            elif args.train_type == 'discrete':
                data, gt_noise = transform_discrete_noise(data, args.batch_size, img_std_candidates=[0, 1, 2, 3], depth_std_candidates=[0, 0.25, 0.5, 0.75])
            else:
                raise Exception('Invalid test type specified')

            batch_results, pred_noise = model(data, controller_temperature=1)
            print(pred_noise[0], gt_noise[0])
            # Every output key carries one predicted distribution per sample.
            for key in batch_results.keys():
                for i in range(len(batch_results[key]['dist'])):
                    # TODO Currently 2D, also introduce hybrid training, use MSE to help convergence at start then use NLL
                    # NOTE(review): mseloss() goes through .item() and returns a
                    # plain float, so the 0.05 * loss_mse term contributes no
                    # gradient; only the NLL term trains the model.
                    loss_mse = mseloss(torch.squeeze(batch_results[key]['dist'][i].mean), torch.squeeze(gt_pos[i][:, [0, 2]]))
                    pos_neg_log_probs = -batch_results[key]['dist'][i].log_prob(torch.squeeze(gt_pos[i][:, [0, 2]]))  # NLL per node/modality combo
                    train_loss += pos_neg_log_probs + 0.05 * loss_mse
                    ad_train_loss += loss_mse / (batch_size * len(batch_results.keys()))

            # Normalize by batch size and number of output keys.
            train_loss /= (batch_size * len(batch_results.keys()))

            with torch.no_grad():
                # Print one sample from the batch to see prediction result and loss
                print('Batch Number', batch_num)
                key = 'early_fusion'
                print('Estimate', batch_results[key]['dist'][0].mean.data, " with cov ", batch_results[key]['dist'][0].covariance_matrix.data)
                sample_mse_loss = mseloss(torch.squeeze(batch_results[key]['dist'][0].mean), torch.squeeze(gt_pos[0][:, [0, 2]]))
                sample_nll_loss = -batch_results[key]['dist'][0].log_prob(torch.squeeze(gt_pos[0][:, [0, 2]]))
                print('\tGT', gt_pos[0], 'with loss', sample_nll_loss + 0.05 * sample_mse_loss)
                print('-------------------------------------------------------------------------------------------------------------------------------')
                epoch_train_loss += train_loss

            # Noise-prediction L1 loss; columns 1 and 2 are weighted 10x.
            noise_loss = torch.mean(torch.abs(gt_noise[:, 0] - pred_noise[:, 0])) + torch.mean(torch.abs(gt_noise[:, 1] - pred_noise[:, 1])) * 10 + torch.mean(torch.abs(gt_noise[:, 2] - pred_noise[:, 2])) * 10

            print("Noise loss", noise_loss)
            if epoch < 1:
                # First epoch: optimize only the noise loss. Drop the
                # localization loss and its graph to free memory.
                del train_loss
                # zeros_like already matches noise_loss's device; the previous
                # unconditional .cuda() crashed on CPU-only runs.
                train_loss = torch.zeros_like(noise_loss)
                del batch_results
                del loss_mse
                del pos_neg_log_probs
                torch.cuda.empty_cache()

            train_loss += noise_loss
            train_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        print('TRAIN LOSS', epoch_train_loss / batch_num)
        scheduler.step()
        print(scheduler.get_last_lr()[0])
        ad_train_loss /= batch_num
        writer.add_scalar("Training loss", epoch_train_loss / batch_num, epoch)

        batch_num = 0
        epoch_val_loss = 0
        # Mirror stdout into validation.txt during validation. Restore stdout in
        # `finally` and let `with` close the file even if validation raises.
        with torch.no_grad(), open(run_dir + '/validation.txt', "w") as log_file:
            temp_std_out = sys.stdout
            sys.stdout = Tee(sys.stdout, log_file)
            try:
                model.eval()
                for batch in val_dataloader:
                    batch_num += 1
                    val_loss = 0.0
                    data, gt_pos = batch['data'], batch['gt_pos']
                    gt_pos = gt_pos.to(device)

                    if args.train_type == 'continuous':
                        # NOTE(review): unlike training, mmWave_std_max is not
                        # passed here, so transform_noise's default applies — confirm.
                        data, _ = transform_noise(data, args.batch_size, img_std_max=4, depth_std_max=0.75)
                    elif args.train_type == 'finite':
                        data, _ = transform_finite_noise(data, args.batch_size, img_std_max=3, depth_std_max=0.75, mmWave_std_max=0.8)
                    elif args.train_type == 'discrete':
                        data, _ = transform_discrete_noise(data, args.batch_size, img_std_candidates=[0, 1, 2, 3], depth_std_candidates=[0, 0.25, 0.5, 0.75])
                    else:
                        raise Exception('Invalid test type specified')

                    # NOTE(review): validation uses the annealed temperature while
                    # training uses a fixed temperature of 1 — confirm intended.
                    batch_results, pred_noise = model(data, controller_temperature=annealer.forward(epoch))
                    print(pred_noise[0])
                    # Validation loss is the L2 position error only.
                    for key in batch_results.keys():
                        for i in range(len(batch_results[key]['dist'])):
                            val_loss += mseloss(torch.squeeze(batch_results[key]['dist'][i].mean), torch.squeeze(gt_pos[i][:, [0, 2]]))
                    # Normalize by samples in this batch and number of output keys.
                    val_loss /= (len(batch_results[key]['dist']) * len(batch_results.keys()))
                    epoch_val_loss += val_loss
                epoch_val_loss /= batch_num
                print("Validation loss", epoch_val_loss)
            finally:
                sys.stdout = temp_std_out

        with open(run_dir + '/log.txt', 'a') as handle:
            print('Epoch ' + str(epoch) + ' | Train loss ' + str(ad_train_loss) + ' | Val Loss ' + str(epoch_val_loss)
                  , file=handle)
        torch.save(model.state_dict(), run_dir + '/last.pt')

    writer.close()
                

if __name__ == '__main__':
    # Script entry point: parse config/CLI options, then run training.
    cli_args = get_args_parser()
    main(cli_args)

