import os
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import numpy as np
from utils import Logger, AverageMeter
from utils.dataloader_segments import Dataset
from torch.utils import data

from integrator.base import Solver




class Train:
    """Train the learnable parts of an ODE integrator against recorded trajectories.

    ``__init__`` builds the training data loader (plus a test loader when
    ``args.test_dir`` is non-empty) and the logger. :meth:`fit` optimises every
    model returned by ``integrator.get_models()`` and saves a checkpoint into
    ``args.ckpt``. An optional ``reference`` solver is rolled out next to the
    integrator so that its error can be reported alongside the learned one.
    """

    def __init__(self, args):
        """Set up datasets, loaders, logger, loss and per-frame loss weights.

        Args:
            args: configuration namespace. Attributes read here: ``test_dir``,
                ``seed``, ``ckpt``, ``train_dir``, ``train_nframe``,
                ``batch_size``, ``test_nframe``, ``beta``, ``use_wt_in_test``.
        """
        print()
        self.args = args
        # Testing is enabled only when a test directory was supplied.
        self.test = self.args.test_dir != ''
        if self.args.seed is not None:
            print('seed set')
            torch.manual_seed(self.args.seed)
            np.random.seed(self.args.seed)

        self.logger = Logger(os.path.join(self.args.ckpt, 'log.txt'), title='')
        names = ['Learning Rate', 'Train Mse', 'Train RK4 Error']
        if self.test:
            names += ['Test Mse', 'Test RK4 Error']
        self.logger.set_names(names)

        training_set = Dataset(
            data_dir=self.args.train_dir,
            nframe=self.args.train_nframe,
        )
        self.trainloader = data.DataLoader(
            training_set,
            shuffle=True,
            batch_size=self.args.batch_size,
            num_workers=0,
        )
        self.train_dt = training_set.dt

        # Test-side attributes exist even without a test set so later code can
        # check them (previously `test_set.dt` raised NameError here).
        self.testloader = None
        self.test_dt = None
        if self.test:
            test_set = Dataset(
                data_dir=self.args.test_dir,
                nframe=self.args.test_nframe,
            )
            self.testloader = data.DataLoader(
                test_set,
                shuffle=False,
                batch_size=128,
                num_workers=0,
            )
            self.test_dt = test_set.dt

        self.n_batch = len(self.trainloader)
        self.criterion = nn.MSELoss(reduction='mean')
        self.optimizers = []
        self.schedulers = []
        self.models = []
        self.device = None
        # Weights beta**1 .. beta**(nframe-1), multiplied onto frames 1..nframe-1
        # of predicted and true trajectories before the MSE.
        self.train_wt = self.args.beta ** torch.arange(1, self.args.train_nframe)
        if self.args.use_wt_in_test:
            self.test_wt = self.args.beta ** torch.arange(1, self.args.test_nframe)
        else:
            self.test_wt = torch.ones(self.args.test_nframe - 1)

    def fit(self, integrator: 'Solver', reference: 'Solver' = None):
        """Train every model of ``integrator`` for ``args.epoch`` epochs.

        Per-batch progress is printed and mirrored to ``<ckpt>/details.txt``;
        one row per epoch goes to the logger. At the end the model and
        optimizer state dicts are saved via :meth:`save_checkpoint`.

        Args:
            integrator: solver exposing ``device``, ``get_models()``,
                ``set_ic()``, ``step(dt)`` and ``x``.
            reference: optional solver with the same interface, rolled out for
                comparison only (its models are not optimised).

        Raises:
            ValueError: if ``integrator`` and ``reference`` are on different
                devices, or ``args.loadparam`` names a missing file.
        """
        use_ref = reference is not None
        if use_ref and integrator.device != reference.device:
            raise ValueError('integrator and reference must be on the same device.')
        self.device = integrator.device
        self.train_wt = self.train_wt.to(self.device)
        self.test_wt = self.test_wt.to(self.device)
        models = integrator.get_models()  # on integrator.device

        self._build_optimizers(models)
        self._init_parameters(models)

        # Field widths for the epoch / iteration counters in progress lines.
        epoch_width = int(np.log10(self.args.epoch + 0.1)) + 1
        iter_width = int(np.log10(self.n_batch + 0.1)) + 1

        with open(os.path.join(self.args.ckpt, 'details.txt'), 'w') as printfile:
            print(dict(self.args._get_kwargs()), file=printfile)
            print('\n', file=printfile)
            for epoch in range(self.args.epoch):
                # Schedulers only step once per epoch, so this is the lr used for
                # every batch of the epoch.
                lr = self.optimizers[0].param_groups[0]["lr"]
                train_avg, train_ref_avg = self._train_epoch(
                    epoch, integrator, reference, lr, epoch_width, iter_width, printfile)

                for scheduler in self.schedulers:
                    scheduler.step(train_avg)

                row = [lr, train_avg, train_ref_avg]
                if self.test:
                    is_test_epoch = epoch == self.args.epoch - 1 or (
                        self.args.test_freq > 0 and epoch % self.args.test_freq == 0)
                    if is_test_epoch:
                        row += self._test_epoch(epoch, integrator, reference, epoch_width, printfile)
                    else:
                        row += [-1, -1]

                print()
                print('', file=printfile)
                # Row length always matches the names registered in __init__.
                self.logger.append(row)

        state = {}
        for i, model in enumerate(models):
            state['model' + str(i)] = model.state_dict()
            state['optim' + str(i)] = self.optimizers[i].state_dict()
        self.save_checkpoint(state)

        self.logger.close()

    def _build_optimizers(self, models):
        """Create one optimizer + ReduceLROnPlateau scheduler per model (replacing any from a previous fit)."""
        self.optimizers = []
        self.schedulers = []
        for model in models:
            if self.args.optim == 'sgd':
                optimizer = optim.SGD(model.parameters(), lr=self.args.lr, momentum=0.9, weight_decay=0)
            else:
                optimizer = optim.Adam(model.parameters(), lr=self.args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=10, cooldown=5, verbose=True, threshold=1E-3)
            self.optimizers.append(optimizer)
            self.schedulers.append(scheduler)

    def _init_parameters(self, models):
        """Initialise Linear weights, or load model/optimizer state from ``args.loadparam``."""
        param_dir = self.args.loadparam
        if param_dir == '__NO_PARAM__':
            for model in models:
                for m in model.modules():
                    if isinstance(m, nn.Linear):
                        nn.init.normal_(m.weight, mean=0, std=0.001)
            print('Parameter initialized')
            return

        if not os.path.isfile(param_dir):
            raise ValueError(f'Parameter File {param_dir} is not valid.')
        # Read the file once and map tensors onto the integrator's device.
        checkpoint = torch.load(param_dir, map_location=self.device)
        # Use the same `models` list the optimizers were built from so that
        # index i refers to the same model in both.
        for i, model in enumerate(models):
            model.load_state_dict(checkpoint['model' + str(i)])
            print('Parameter loaded')
            if self.args.loadopt:
                self.optimizers[i].load_state_dict(checkpoint['optim' + str(i)])
                print('Optimization state loaded')

    def _train_epoch(self, epoch, integrator, reference, lr, epoch_width, iter_width, printfile):
        """Train over all batches once; return (avg train MSE, avg reference MSE or -1 without a reference)."""
        use_ref = reference is not None
        train_mse = AverageMeter()
        train_ref_loss = AverageMeter()
        for current_iter, trajectories in enumerate(self.trainloader):
            # trajectories has shape (batch, n, n_frame)
            out = self.train_batch(trajectories, integrator, reference)
            train_loss, batch_time = out[0], out[1]
            ref_loss = out[2] if use_ref else None
            ratio = self._ratio(train_loss, ref_loss)

            suffix = f'Training epoch: {epoch+1:{epoch_width}d} iter: {current_iter+1:{iter_width}d} / {self.n_batch:d} | lr: {lr:.6f} | Batch: {batch_time:.2f}s | training MSE: {train_loss:.4g} ({ratio:3.2f} %) |'
            print(suffix)
            print(suffix, file=printfile)

            count = torch.numel(trajectories[:, :, 1:])
            train_mse.update(train_loss, count)
            if use_ref:
                train_ref_loss.update(ref_loss, count)
        return train_mse.avg, (train_ref_loss.avg if use_ref else -1)

    def _test_epoch(self, epoch, integrator, reference, epoch_width, printfile):
        """Evaluate over the test loader; return [avg test MSE, avg reference MSE or -1]."""
        use_ref = reference is not None
        test_mse = AverageMeter()
        test_ref_loss = AverageMeter()
        for trajectories in self.testloader:
            out = self.test_batch(trajectories, integrator, reference, use_wt=self.args.use_wt_in_test)
            test_loss, batch_time = out[:2]
            ref_loss = out[-1] if use_ref else None
            ratio = self._ratio(test_loss, ref_loss)

            count = torch.numel(trajectories[:, :, 1:])
            test_mse.update(test_loss, count)
            if use_ref:
                test_ref_loss.update(ref_loss, count)

            suffix = f'Testing epoch: {epoch+1:{epoch_width}d} | Batch: {batch_time:.2f}s | test MSE: {test_loss:.4g} ({ratio:3.2f} %) |'
            print(suffix)
            print(suffix, file=printfile)
        return [test_mse.avg, test_ref_loss.avg if use_ref else -1]

    @staticmethod
    def _ratio(loss, ref_loss):
        """Return loss as a percentage of ref_loss; -1 without a reference, inf if ref_loss is 0."""
        if ref_loss is None:
            return -1
        if ref_loss == 0:
            return float('inf')
        return loss / ref_loss * 100

    def _rollout(self, traj, integrator, reference, nframe, dt):
        """Integrate from frame 0 for ``nframe - 1`` steps of size ``dt``.

        ``traj`` is a (bs, n, n_frame) batch; it is moved to ``self.device`` and
        permuted to (n, bs, n_frame). Returns ``(traj, predict_traj, ref_traj)``
        in that layout, with ``ref_traj`` None when ``reference`` is None.
        Frame 0 of the predictions is left as zeros.
        """
        traj = traj.to(self.device).permute(1, 0, 2)  # (n, bs, n_frame)
        integrator.set_ic(traj[:, :, 0])
        predict_traj = torch.zeros(traj.shape, device=integrator.device)
        ref_traj = None
        if reference is not None:
            reference.set_ic(traj[:, :, 0])
            ref_traj = torch.zeros(traj.shape, device=reference.device)

        for i in range(1, nframe):
            integrator.step(dt)
            predict_traj[:, :, i] = integrator.x
            if reference is not None:
                reference.step(dt)
                ref_traj[:, :, i] = reference.x
        return traj, predict_traj, ref_traj

    def _weighted_mse(self, pred, target, wt):
        """MSE over frames 1.. of (n, bs, n_frame) tensors, each frame scaled by ``wt``."""
        return self.criterion(pred[:, :, 1:] * wt, target[:, :, 1:] * wt)

    def train_batch(self, traj, integrator: 'Solver', reference: 'Solver' = None):
        """Roll out one batch with ``train_dt``, backpropagate and step all optimizers.

        Returns ``(loss, seconds)`` or ``(loss, seconds, ref_loss)`` when a
        reference solver is given.
        """
        for model in integrator.get_models():
            model.train()

        start_time = time.time()
        traj, predict_traj, ref_traj = self._rollout(
            traj, integrator, reference, self.args.train_nframe, self.train_dt)
        loss = self._weighted_mse(predict_traj, traj, self.train_wt)
        ref_loss = None if ref_traj is None else self._weighted_mse(ref_traj, traj, self.train_wt)

        for optimizer in self.optimizers:
            optimizer.zero_grad()
        loss.backward()
        for optimizer in self.optimizers:
            optimizer.step()
        batch_time = time.time() - start_time

        if ref_loss is not None:
            return loss.item(), batch_time, ref_loss.item()
        return loss.item(), batch_time

    def test_batch(self, traj, integrator: 'Solver', reference: 'Solver' = None, use_wt=False):
        """Roll out one batch in eval mode with the test set's ``dt`` and return its loss.

        ``use_wt`` is accepted for backward compatibility only. The test
        weighting is fixed in ``__init__`` from ``args.use_wt_in_test``.

        Returns ``(loss, seconds)`` or ``(loss, seconds, ref_loss)`` when a
        reference solver is given.
        """
        for model in integrator.get_models():
            model.eval()

        # Use the test data's time step; it previously (wrongly) used train_dt.
        dt = self.test_dt if self.test_dt is not None else self.train_dt
        start_time = time.time()
        traj, predict_traj, ref_traj = self._rollout(
            traj, integrator, reference, self.args.test_nframe, dt)
        batch_time = time.time() - start_time

        loss = self._weighted_mse(predict_traj, traj, self.test_wt)
        if ref_traj is not None:
            ref_loss = self._weighted_mse(ref_traj, traj, self.test_wt)
            return loss.item(), batch_time, ref_loss.item()
        return loss.item(), batch_time

    def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
        """Save ``state`` with ``torch.save`` to ``<args.ckpt>/<filename>``."""
        filepath = os.path.join(self.args.ckpt, filename)
        torch.save(state, filepath)
            
        
        
        
